CV Week: Итоговое задание¶
На лекции и семинаре мы разбирали как дистиллировать многошаговую диффузионную модель в малошагового студента, и тем самым будет работать на порядок быстрее учителя.
Один из подходов, который мы разбирали Consistency Distillation. В этом задании, мы закрепим материал, который был на лекции и семинаре и реализуем этот фреймворк, затрагивая различные нюансы.
В этом задании мы будем дистиллировать модель Stable Diffusion 1.5 (SD1.5) для генерации картинок по текстовому описанию.
Вам предстоит выполнить 8 небольших заданий, которые приведут нас к неплохой модели для генерации картинок за 4 шага, работая в органиченных условиях колаба.
# torch 2.4.1+cu124
# !pip install diffusers==0.30.3, peft==0.8.2 huggingface_hub==0.23.4
Теормин¶
Диффузионные модели¶
Задан прямой диффузионный процесс, который переводит чистые картинки в шум с помощью распределения $q(\mathbf{x}_t | \mathbf{x}_0)= {N}(\mathbf{x}_t | \alpha_t \mathbf{x}_0, \sigma^2_t I)$
Таким образом, мы можем получаться зашумленные картинки по следующей формуле: $\mathbf{x}_t = \alpha_t \mathbf{x}_0 + \sigma_t \epsilon$, где $\epsilon{\sim} {N}(0, I)$ (1)
$\alpha_t, \sigma_t$ задают процесс зашумления. Здесь мы будем иметь дело с variance preserving (VP) процессом, т. е., $\alpha^2_t = 1 - \sigma^2_t$.
Диффузионная модель (ДМ) пытается решить обратную задачу: из шума порождать новые картинки. Важно, что диффузионный процесс можно описать следующим обыкновенным дифференциальным уравнением (ОДУ):
$dx = \left[ f(\mathbf{x}, t) - \frac{1}{2} \nabla_{\mathbf{x}_t} \log p_t(\mathbf{x}) \right] dt$, (2)
где $f(\mathbf{x}, t)$ известен из заданного процесса зашумления, а $\nabla_{\mathbf{x}_t} \log p_t(\mathbf{x}_t)$ (скор функцию) оцениваем с помощью нейросети: $s_\theta(\mathbf{x}_t, t) \approx \nabla_{\mathbf{x}_t} \log p_t(\mathbf{x}_t)$. Таким образом, имея оценку на $\nabla_{\mathbf{x}_t} \log p_t(\mathbf{x})$, мы можем решить это ОДУ, стартуя со случайного шума, и получить картинку.
SD1.5 использует $\epsilon$-параметризацию, т.е., UNet пытается предсказать шум, который мы добавили на картинку по формуле (1). Оценку скор функции можно получить, пользуясь результатом, вытекающим из формулы Твидди: $s_\theta(\mathbf{x}_t, t) = - \frac{\epsilon_\theta(\mathbf{x}_t, t)} { \sigma_t}$
Чтобы решить ОДУ (2), нам нужно воспользоваться каким-то численным методом (солвером). В этом задании мы будем работать с не самым эффектным, но самым популярным солвером: DDIM, который является адаптированным методом Эйлера под диффузионный ОДУ.
Для VP процесса переход с помощью DDIM с шага $t$ на $s$ можно сделать следующим образом:
$ x_s = DDIM(\mathbf{x}_t, t, s) = \alpha_s \cdot \left(\frac{\mathbf{x}_t - \sigma_t \epsilon_\theta}{\alpha_t} \right) + \sigma_s \epsilon_\theta $
Этот переход можно интерпретировать так: получаем оценку на чистую картинку $\mathbf{x}_0$ на шаге $t$, используя $\frac{\mathbf{x}_t - \sigma_t \epsilon_\theta}{\alpha_t}$, а потом снова зашумляем эту оценку на шаг $s$ по формуле (1), но только используем не случайный шум, а шум предсказанный моделью $\epsilon_\theta$.
Используя DDIM для SD1.5, можем получать хорошие картинки за 50 шагов.
SD1.5 - латентная ДМ, т.е. модель работает не в пиксельном пространстве, а в латентном пространстве VAE. Таким образом SD1.5 состоит из следующих компонент:
- VAE - переводит $3{\times}512{\times}512$ картинки в латенты $4{\times}64{\times}64$ и может декодировать их обратно в картинки.
- Текстовый энкодер - извлекает текстовые признаки из промпта. Эти признаки будут подаваться в диффузионную модель, чтобы дать модели информацию, что именно хотим сгенерировать
- Диффузионная модель - UNet, работающий на "латентных картинках" $4{\times}64{\times}64$.
Консистенси модели¶
Общая идея¶
Главная цель дистилляции диффузии - уменьшить количество шагов ДМ, при этом сохранив высокое качество картинок.
Консистенси модели (Consistency Models | CM) - класс моделей, где мы хотим выучить "консистенси функцию" $f_\theta(\mathbf{x}_t)$ - с любой точки $\mathbf{x}_{t}$ траектории диффузионного ОДУ (2) сразу предсказывать $\mathbf{x}_{0}$ (чистые данные) за один шаг. Если мы идеально выучим консистенси функцию, то сможем шагать из чистого шума сразу в картинку, что супер эффективно в отличии от генерации ДМ.
Отметим, что консистенси модель можно учить как независимую генеративную модель, без предобученной ДМ, и в задании 3 вам предстоит подумать, как это можно сделать.

Консистенси дистилляция (Consistency Distillation | CD) - подход, когда для обучения CM, мы используем предобученную ДМ. ДМ нам дает качественную инициализацию модели и уже обученную скор функцию, что сильно упрощает сходимость консистенси моделей.
Обучение CM¶
Главная принцип обучения консистенси моделей заключается в попытке удовлетворить self-consistency св-ву: выход CM на двух соседних точках траектории $\mathbf{x}_{t}$ и $\mathbf{x}_{t-1}$ должен совпадать по какой-то мере близости, например L2 расстояние: $\lVert f_\theta(\mathbf{x}_{t-1}) - f_\theta(\mathbf{x}_{t}) \rVert^2_2$.
Заметим, что self-consistency св-во удовлетворить очень просто без какого-либо обучения, взяв, например $f_\theta(\mathbf{x}_{t}) \equiv 0$.
Поэтому, чтобы избежать вырожденных решений, нам необходимо выставить граничное условие (boundary condition), которое будет требовать, чтобы в самой левой точке траектории около 0, модель предсказывала картинку, которую получает на вход: $f_\theta(\mathbf{x}_{\epsilon}) = \mathbf{x}_{\epsilon}$.
Практическое замечание: Для обеих точек траектории мы применяем одну и ту же модель $f_\theta(\cdot)$. Но выход модели на шаге ${t-1}$ является "таргетом" для выхода модели на шаге $t$ и поэтому выполнение модели для шага $t-1$ выполняется в torch.no_grad режиме.
Как получаться две соседние точки на траектории ОДУ?
Берем случайную картинку $\mathbf{x}_0$ из датасета.
Точку $\mathbf{x}_t$ получаем с помощью прямого процесса зашумления: $\mathbf{x}_t = q(\mathbf{x}_t | \mathbf{x}_0)$
Чтобы получить соседнюю точку $\mathbf{x}_{t-1}$, нам нужно сделать шаг по траектории ОДУ, используя, например, DDIM солвер.
В консистенси дистилляции, мы делаем шаг предобученной ДМ: $\mathbf{x}_{t-1} = DDIM(\epsilon_\theta(\mathbf{x}_t, t), \mathbf{x}_t, t, t-1)$
Важно: на практике мы можем брать не соседние шаги $t$ и $t-1$, а с некоторым интервалом, например 20 шагов. Размер интервала влияет на bias/variance trade-off в консистенси обучении: больше интервал между шагами - больше смещение, но меньше дисперсия, и наоборот. Для простоты в этом задании мы зафиксируем интервал - 20 шагов, но во многих работах размер интервала динамически меняют по ходу обучения.
from tqdm.auto import tqdm
import csv
import os
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline, LCMScheduler, UNet2DConditionModel, DDIMScheduler
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
%matplotlib inline
import matplotlib.pyplot as plt
2024-12-09 22:53:22.236387: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2024-12-09 22:53:22.236469: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2024-12-09 22:53:22.302624: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered 2024-12-09 22:53:22.442945: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. 2024-12-09 22:53:23.908561: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
# ---------------------
# Visualization utils
# ---------------------
def visualize_images(images):
assert len(images) == 4
plt.figure(figsize=(12, 3))
for i, image in enumerate(images):
plt.subplot(1, 4, i + 1)
plt.imshow(image)
plt.axis("off")
plt.subplots_adjust(wspace=-0.01, hspace=-0.01)
# --------------
# Tensor utils
# --------------
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
# ---------------
# Dataset utils
# ---------------
class COCODataset(torch.utils.data.Dataset):
def __init__(self, root_dir, subset_name="train2014_5k", transform=None, max_cnt=None):
"""
Arguments:
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
"""
self.root_dir = root_dir
self.transform = transform
self.extensions = (
".jpg",
".jpeg",
".png",
".ppm",
".bmp",
".pgm",
".tif",
".tiff",
".webp",
)
sample_dir = os.path.join(root_dir, subset_name)
# Collect sample paths
self.samples = sorted(
[
os.path.join(sample_dir, fname)
for fname in os.listdir(sample_dir)
if fname[-4:] in self.extensions
],
key=lambda x: x.split("/")[-1].split(".")[0],
)
self.samples = (
self.samples if max_cnt is None else self.samples[:max_cnt]
) # restrict num samples
# Collect captions
self.captions = {}
with open(os.path.join(root_dir, f"{subset_name}.csv"), newline="\n") as csvfile:
spamreader = csv.reader(csvfile, delimiter=",")
for i, row in enumerate(spamreader):
if i == 0:
continue
self.captions[row[1]] = row[2]
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
sample_path = self.samples[idx]
sample = Image.open(sample_path).convert("RGB")
if self.transform:
sample = self.transform(sample)
return {
"image": sample,
"text": self.captions[os.path.basename(sample_path)],
"idxs": idx,
}
Модель учителя (SD1.5)¶
Задание №1¶
Давайте для начала загрузим модель StableDiffusion 1.5 и сгенерируем ей картинки за 50 шагов.
Важно: для экономии памяти, загружаем все компоненты модели в FP16. Не забываем положить модель на GPU.
model_id = "sd-legacy/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(
model_id,
torch_dtype=torch.float16,
variant="fp16",
)
pipe = pipe.to("cuda")
# Проверяем, что все компоненты модели в FP16 и на cuda
assert pipe.unet.dtype == torch.float16 and pipe.unet.device.type == "cuda"
assert pipe.vae.dtype == torch.float16 and pipe.vae.device.type == "cuda"
assert pipe.text_encoder.dtype == torch.float16 and pipe.text_encoder.device.type == "cuda"
# Заменяем дефолтный сэмплер на DDIM
pipe.scheduler = DDIMScheduler.from_config(
pipe.scheduler.config,
timestep_spacing="trailing",
)
pipe.scheduler.timesteps = pipe.scheduler.timesteps.cuda()
pipe.scheduler.alphas_cumprod = pipe.scheduler.alphas_cumprod.cuda()
# Отдельно извлечем модель учителя, которую потом будем дистиллировать
teacher_unet = pipe.unet
Loading pipeline components...: 0%| | 0/7 [00:00<?, ?it/s]
Теперь сгенерируем картинки за 50 шагов. Вам нужно написать вызов pipe и передать в него промпт, число шагов генерации, генератор случайных чисел, гайденс скейл и указать, чтобы сгенерировалось 4 картинки на промпт.
prompt = "A sad puppy with large eyes"
guidance_scale = 7.5
generator = torch.Generator("cuda").manual_seed(1)
# generate 4 images
images = pipe(
prompt,
guidance_scale=guidance_scale,
generator=generator,
num_images_per_prompt=4,
num_inference_steps=50,
).images # type: ignore
visualize_images(images)
0%| | 0/50 [00:00<?, ?it/s]
Давайте посмотрим, что выдаст модель за 4 шага. Все то же самое, что и выше, просто поменяем число шагов.
generator = torch.Generator("cuda").manual_seed(1)
images = pipe(
prompt,
# guidance_scale=guidance_scale,
generator=generator,
num_images_per_prompt=4,
num_inference_steps=4,
).images # type: ignore
visualize_images(images)
0%| | 0/4 [00:00<?, ?it/s]
На 4 шагах картинки получаются размазанными. Давайте постараемся починить их.
Создаем датасет¶
Чтобы ДЗ было легко выполнимым на colab, мы будем учить консистенси модели на небольшой обучающей выборке из 5000 пар текст-картинка из COCO датасета. Интересное свойство консистенси моделей - они могут сходиться до адекватного качества за несколько сотен шагов. Качество все еще будет не идеальным, но фазовый переход уже должен быть заметен.
Данные можно загрузить с помощью команд в ячейке ниже. В локальной текущей директории ./ должны появиться:
- Папка train2014_5k с 5000 картинками
- Файл train2014_5k.csv с 5000 промптами
Данные парсятся корректным образом в уже реализованном классе COCODataset.
# !wget https://storage.yandexcloud.net/yandex-research/train2014_5k.tar.gz
# !tar -xzf train2014_5k.tar.gz
Замечание: для более быстрого дебаггинга можете взять, например, 2500 картинок и прогнать на всей выборке только в самом конце. 2500 картинок должно быть достаточно для понимания корректно ли реализованы функции. Совсем для первичного дебаггинга можно взять еще меньше картинок.
from torchvision import transforms
transform = transforms.Compose(
[
transforms.Resize(512),
transforms.CenterCrop(512),
transforms.ToTensor(),
lambda x: 2 * x - 1,
]
)
dataset = COCODataset(
".",
subset_name="train2014_5k",
transform=transform,
# max_cnt=2500,
)
# assert len(dataset) == 2500 # 2500
assert len(dataset) == 5000
batch_size = 8 # Рекоммендуемы размер батча на Colab
train_dataloader = torch.utils.data.DataLoader(
dataset=dataset, shuffle=True, batch_size=batch_size, drop_last=True
)
@torch.no_grad()
def prepare_batch(batch, pipe):
"""
Предобработка батча картинок и текстовых промптов.
Маппим картинки в латентное пространство VAE.
Извлекаем эмбеды промптов с помощью текстового энкодера.
Params:
Return:
latents: torch.Tensor([B, 4, 64, 64], dtype=torch.float16)
prompt_embeds: torch.Tensor([B, 77, D], dtype=torch.float16)
"""
# Токенизируем промпты
text_inputs = pipe.tokenizer(
batch["text"],
padding="max_length",
max_length=pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
# Извлекаем эмбеды промптов с помощью текстового энкодера
prompt_embeds = pipe.text_encoder(text_inputs.input_ids.cuda())[0]
# Переводим картинки в латентное пространство VAE
image = batch["image"].to("cuda", dtype=torch.float16)
latents = pipe.vae.encode(image).latent_dist.sample()
latents = latents * pipe.vae.config.scaling_factor
return latents, prompt_embeds
Подготовка моделей и оптимизатора¶
Для начала создаем обучаемую модель: UNet инициализируемый весами SD1.5. Вам нужно воспользоваться классом UNet2DConditionModel и загрузить отдельно только UNet модель из SD1.5.
Отметим, что эта модель у нас будет храниться в полной точности FP32, потому что обучение параметров в FP16 может приводить к нестабильностям и низкому качеству.
unet = UNet2DConditionModel.from_pretrained(
"sd-legacy/stable-diffusion-v1-5",
subfolder="unet",
device_map="balanced",
)
unet.train()
assert unet.dtype == torch.float32
assert unet.training
Для экономии памяти во время обучения будем учить не параметры самой модели, а добавим в нее обучаемые LoRA адаптеры с малым числом параметров.
LoRA представляет собой маленькую добавку к весам модели, где на одну матрицу весов $W \in \mathbb{R}^{m{\times}n} $ обучаются две низкоранговые матрицы $W_A \in \mathbb{R}^{k{\times}n}$ и $W_B \in \mathbb{R}^{k{\times}m}$, где $k$ - ранг матрицы сильно меньше $m$ и $n$.
Тем самым, новая обученная матрица весов может быть представлена как $\hat{W} = W + \Delta W = W + W^T_B W_A$.
Во время инференса $\Delta W$ можно вмержить в $W$ и получить итоговую модель.
Также частая практика оставлять адаптеры как есть, чтобы была возможность для одной базовой модели учить несколько адаптеров под разные задачи и переключаться между ними по необходимости.
Если не мержить адаптеры, то вычисления для линейного слоя происходят как на картинке ниже.
# Указываем к каким слоям модели мы будет добавлять адаптеры.
lora_modules = [
"to_q",
"to_k",
"to_v",
"to_out.0",
"proj_in",
"proj_out",
"ff.net.0.proj",
"ff.net.2",
"conv1",
"conv2",
"conv_shortcut",
"downsamplers.0.conv",
"upsamplers.0.conv",
"time_emb_proj",
]
lora_config = LoraConfig(r=64, target_modules=lora_modules) # задает ранг у матриц A и B в LoRA.
# Создаем обертку исходной UNet модели с LoRA адаптерами, используя библиотеку PEFT
cm_unet = get_peft_model(unet, lora_config, adapter_name="ct")
# Включаем gradient checkpointing - важная техника для экономии памяти во время обучения
cm_unet.enable_gradient_checkpointing()
# Создаем оптимизатор
optimizer = torch.optim.AdamW(cm_unet.parameters(), lr=1e-4)
# Задаем лосс функцию для CM обжектива. В базовом варианте разумно взять L2
# По умолчанию, она уже выдает усредненное значение по всем размерностям
mse_loss = torch.nn.functional.mse_loss
Задание №2 (0.5 балла, сдается в контесте)¶
Реализация шага DDIM¶
Шаг с помощью DDIM с $\mathbf{x}_t$ на $\mathbf{x}_s$ можно сделать следующим образом:
$ \mathbf{x}_s = DDIM(\epsilon_\theta, \mathbf{x}_t, t, s) = \alpha_s \cdot \left(\frac{\mathbf{x}_t - \sigma_t \epsilon_\theta}{\alpha_t} \right) + \sigma_s \epsilon_\theta $
Вам нужно реализовать эту формулу в уже готовом шаблоне ниже. Чтобы корректно выполнить задание, вам нужно задать $\alpha_t$ и $\sigma_t$ имея DDIMScheduler. **Обратите внимание на аттрибут *scheduler.alphas_cumprod***, который задает $\bar\alpha_{t} = \prod^t_{i=1} (1-\beta_i)$ в классической DDPM формулировке: Denoising Diffusion Probabilistic Models.
def ddim_solver_step(model_output, x_t, t, s, scheduler): # -> Any:
"""
Шаг DDIM солвера для VP процесса зашумления и eps-prediction модели
params:
model_output: torch.Tensor[B, 4, 64, 64] - предсказание модели - шум eps
x_t: torch.Tensor[B, 4, 64, 64] - сэмплы на шаге t
t: torch.Tensor[B] - номер текущего шага
s: torch.Tensor[B] - номер следующего шага
scheduler: DDIMScheduler - расписание диффузионного процесса, чтобы получить alpha и sigma
"""
alphas = torch.sqrt(scheduler.alphas_cumprod)
sigmas = torch.sqrt(1.0 - scheduler.alphas_cumprod)
sigmas_s = extract_into_tensor(sigmas, s, x_t.shape)
alphas_s = extract_into_tensor(alphas, s, x_t.shape)
sigmas_t = extract_into_tensor(sigmas, t, x_t.shape)
alphas_t = extract_into_tensor(alphas, t, x_t.shape)
# Выставляем крайние значения alpha и sigma, чтобы выполнялись граничные условия
alphas_s[s == 0] = 1.0
sigmas_s[s == 0] = 0.0
alphas_t[t == 0] = 1.0
sigmas_t[t == 0] = 0.0
# Reverse diffusion formula
x_0 = (x_t - sigmas_t * model_output) / alphas_t
# DDIM formula
x_s = alphas_s * x_0 + sigmas_s * model_output
return x_s
Реализация процесса зашумления (q sample)¶
Аналогично, нам нужен процесс зашумления $q(\mathbf{x}_t | \mathbf{x}_0)= {N}(\mathbf{x}_t | \alpha_t \mathbf{x}_0, \sigma^2_t I)$
$\mathbf{x}_t = \alpha_t \mathbf{x}_0 + \sigma_t \epsilon$, где $\epsilon{\sim} {N}(0, I)$
def q_sample(x, t, scheduler, noise=None):
alphas = torch.sqrt(scheduler.alphas_cumprod)
sigmas = torch.sqrt(1.0 - scheduler.alphas_cumprod)
if noise is None:
noise = torch.randn_like(x)
sigmas_t = extract_into_tensor(sigmas, t, x.shape)
alphas_t = extract_into_tensor(alphas, t, x.shape)
x_t = x * alphas_t + sigmas_t * noise
return x_t
Consistency Training¶
Обучение консистенси моделей без учителя называется Consistency Training (CT). В таком случае CM можно рассматривать как отдельный вид генеративных моделей. Давайте начнем именно с этого подхода и обучим нашу первую консистенси модель на базе SD1.5.
Задание №3¶
Задание №3.1 (0.5 балла, сдается в контесте)¶
В консиcтенси дистилляции модель учителя используется для получения второй точки на траектории ODE. Можем ли мы попробовать оценить соседнюю точку аналитически?
Вам предлагается вывести это самим, используя формулу DDIM шага выше и вспомнив, как мы оцениваем скор функции в denoising score matching-e:
$\epsilon_\theta(x_t, t) = - \sigma_t s_\theta(x_t, t)$
$s_\theta(x_t, t) \approx \nabla_{x_t} \log q(x_t) = \mathop{\mathbb{E}}_{\mathbf{x}\sim p_{data}}\left [ \nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t | \mathbf{x}) \vert \mathbf{x}_t \right ] \approx \nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t \vert \mathbf{x})$
< YOUR DERIVATION HERE > $$ \mathbf{x}_s = DDIM(\epsilon_\theta, \mathbf{x}_t, t, s) = \alpha_s \cdot \left(\frac{\mathbf{x}_t - \sigma_t \epsilon_\theta}{\alpha_t} \right) + \sigma_s \epsilon_\theta = \\ \alpha_s \cdot \left(\frac{\mathbf{x}_t + \sigma_t^2 s_\theta(x_t, t)}{\alpha_t} \right) - \sigma_s \sigma_t s_\theta(x_t, t) = \\ \alpha_s \cdot \left(\frac{\mathbf{x}_t + \sigma_t^2 \nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t \vert \mathbf{x})}{\alpha_t} \right) - \sigma_s \sigma_t \nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t \vert \mathbf{x}) = \\ \alpha_s \cdot \left(\frac{\mathbf{x}_t - \sigma_t^2 \cdot \left(\frac{\mathbf{x}_t - \alpha_t \mathbf{x}_0}{\sigma_t^2}\right)}{\alpha_t} \right) + \sigma_s \sigma_t \cdot \left(\frac{\mathbf{x}_t - \alpha_t \mathbf{x}_0}{\sigma_t^2}\right) = \\ \alpha_s \mathbf{x}_0 + \sigma_s \cdot \left(\frac{\mathbf{x}_t - \alpha_t \mathbf{x}_0}{\sigma_t}\right) $$
Если возникнут трудность, можно обратиться к оригинальной статье.
Теперь реализуем то, что у вас получилось в функции ниже.
def get_xs_from_xt_naive(x_0, x_t, t, s, scheduler, noise=None, **kwargs):
"""
Получение точки x_s в CT режиме, т.е., аналитически.
"""
if x_0 is None:
x_0 = torch.zeros_like(x_t)
if x_t is None:
x_t = q_sample(x_0, t, scheduler, noise=noise)
if (t == s).all():
return x_t
alphas = torch.sqrt(scheduler.alphas_cumprod)
sigmas = torch.sqrt(1.0 - scheduler.alphas_cumprod)
alpha_t = extract_into_tensor(alphas, t, x_t.shape)
sigma_t = extract_into_tensor(sigmas, t, x_t.shape)
alpha_s = extract_into_tensor(alphas, s, x_t.shape)
sigma_s = extract_into_tensor(sigmas, s, x_t.shape)
alpha_t[t == 0] = 1.0
sigma_t[t == 0] = 0.0
x_s = x_0.clone().detach()
cond_1 = (sigma_t != 0) * (sigma_s != 0)
x_s = torch.where(
cond_1,
alpha_s * x_0 + sigma_s * (x_t - alpha_t * x_0) / sigma_t,
x_s,
)
cond_2 = (sigma_t == 0) * (sigma_s != 0)
x_s = torch.where(
cond_2,
q_sample(x_0, s, scheduler, noise),
x_s,
)
return x_s
Задание №3.2¶
Ниже представлен шаблон функции, которая считает лосс для консистенси моделей. Вам нужно правильно заполнить пропуски, чтобы получилась корректная функция.
def cm_loss_template(
latents,
prompt_embeds, # батч латентов и текстовых эмбедов
unet,
scheduler,
# Функции, которые будем постепенно менять из задания к заданию
loss_fn: callable,
get_boundary_timesteps: callable,
get_xs_from_xt: callable,
num_timesteps=1000,
step_size=20, # Указываем с каким интервалом берем шаги s и t.
):
# Сэмплируем случайные шаги t для каждого элемента батча t ~ U[step_size-1, 999]
assert num_timesteps == 1000
num_intervals = num_timesteps // step_size
index = torch.randint(
1, num_intervals, (len(latents),), device=latents.device
).long() # [1, num_intervals]
t = step_size * index - 1
s = torch.clamp(t - step_size, min=0)
boundary_timesteps = get_boundary_timesteps(s, num_timesteps=num_timesteps)
# Сэмплируем x_t
noise = torch.randn_like(latents)
x_t = q_sample(latents, t, scheduler, noise=noise)
# with <YOUR CODE HERE>: # для реализации mixed-precision обучения в задании №4
with torch.cuda.amp.autocast(dtype=torch.float16):
noise_pred = unet(
x_t.float(),
t,
encoder_hidden_states=prompt_embeds.float(),
).sample
# Получаем оценку в граничной точке для x_t
boundary_pred = ddim_solver_step(noise_pred, x_t, t, boundary_timesteps, scheduler)
x_s = get_xs_from_xt(
latents,
x_t,
t,
s,
scheduler,
prompt_embeds=prompt_embeds,
noise=noise,
)
# Предсказание "таргет моделью"
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
target_noise_pred = unet(x_s, s, encoder_hidden_states=prompt_embeds).sample
# Получаем оценку в граничной точке для x_s
boundary_target = ddim_solver_step(target_noise_pred, x_s, s, boundary_timesteps, scheduler)
loss = loss_fn(boundary_pred, boundary_target)
return loss
import functools
def get_zero_boundary_timesteps(t, **kwargs):
"""
Определяем шаги где будут срабатывать граничные условия.
Для классических СM это t=0.
"""
return torch.zeros_like(t)
ct_loss = functools.partial(
cm_loss_template,
loss_fn=mse_loss,
get_boundary_timesteps=get_zero_boundary_timesteps,
get_xs_from_xt=get_xs_from_xt_naive,
)
assert cm_unet.active_adapter == "ct"
Задание №4¶
Эффективное обучение¶
Данное задание рассчитано на успешное выполнение на colab с бесплатной Tesla T4 c 15GB VRAM. Однако учить даже относительно небольшие T2I модели масштаба SD1.5 уже на коллабе в лоб проблематично.
Для этого нам нужно применить ряд инженерных техник, чтобы уместиться в данный бюджет и учиться за разумное время.
Список техник
- Включить gradient checkpointing для обучемой модели
- Добавить LoRA (Low Rank Adapters) адаптеры, чтобы учить не все веса, а только 10% добавочных весов
- Использовать gradient accumulation, чтобы делать итерацию обучения по бОльшему батчу, чем влезает по памяти
- Добавить mixed precision FP16/FP32 обучение модели для скорости. Обычно еще и память экономится, но в случае LoRA обучения + gradient checkpointing на память сильно влиять не должно, но зато станет быстрее.
- Мульти-GPU обучение - распределение вычислений по нескольким GPU.
1-2) Мы уже применили за вас выше
3-4) Предстоит реализовать вам самим в соотвествующей секции ниже
5 ) Недоступно, так как работаем на одной карточке
Обучающий цикл¶
Вам дан код обучения модель в полной точности (FP32) c батчом 8. К сожалению, на Tesla T4 мы не влезем по памяти. Поэтому в ячейке ниже вам нужно модифицировать цикл, чтобы он работал в mixed precision FP16 и добавить gradient accumulation.
Про реализацию mixed-precision в pytorch можно перейти по ссылке: Mixed-precision обучение
Обратите внимание: вам еще нужно добавить одну строчку кода в cm_loss_template в соответствующем плейсхолдере.
Замечание: В начале обучения значения лосса должны быть в окрестности 0.0007-0.001. Ничего страшного, что лосс не падает, для CM это нормально. В конце обучения лосс может доходить до 0.005-0.01
def train_loop(model, pipe, train_dataloader, optimizer, loss_fn, num_grad_accum=1):
torch.cuda.empty_cache()
scaler = torch.cuda.amp.GradScaler()
for i, batch in enumerate(tqdm(train_dataloader)):
with torch.cuda.amp.autocast(dtype=torch.float16):
latents, prompt_embeds = prepare_batch(batch, pipe)
loss = loss_fn(latents, prompt_embeds, model, pipe.scheduler) / num_grad_accum
# Обновляем параметры
scaler.scale(loss).backward()
if (i + 1) % num_grad_accum == 0:
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
print(f"Loss: {loss.detach().item()}")
num_grad_accum = 2 # обновляем параметры каждые 2 шага
train_loop(cm_unet, pipe, train_dataloader, optimizer, ct_loss, num_grad_accum)
0%| | 0/625 [00:00<?, ?it/s]
Loss: 0.00037942116614431143 Loss: 0.0005064992001280189 Loss: 0.00041276204865425825 Loss: 0.0008958675898611546 Loss: 0.00046880001900717616 Loss: 0.0005985702155157924 Loss: 0.0004536816559266299 Loss: 0.00042376824421808124 Loss: 0.0007258296245709062 Loss: 0.00047782735782675445 Loss: 0.0005945760058239102 Loss: 0.0005281904013827443 Loss: 0.001687874086201191 Loss: 0.0008240551687777042 Loss: 0.0008860052330419421 Loss: 0.0008627125062048435 Loss: 0.0004061653162352741 Loss: 0.0004743822501040995 Loss: 0.0008747418178245425 Loss: 0.0005310914712026715 Loss: 0.0005723602953366935 Loss: 0.0008468653541058302 Loss: 0.0004579655360430479 Loss: 0.0005488618044182658 Loss: 0.0005751841817982495 Loss: 0.0004992288304492831 Loss: 0.0007894040900282562 Loss: 0.0007330112857744098 Loss: 0.0008934878278523684 Loss: 0.00043908506631851196 Loss: 0.000619879923760891 Loss: 0.0004753718967549503 Loss: 0.0006294205086305737 Loss: 0.0007459073094651103 Loss: 0.0005154737154953182 Loss: 0.0004189095343463123 Loss: 0.0014230990782380104 Loss: 0.0007570700254291296 Loss: 0.0007128705619834363 Loss: 0.0004862792557105422 Loss: 0.0008252396946772933 Loss: 0.000948379107285291 Loss: 0.0003448168863542378 Loss: 0.0014016738859936595 Loss: 0.00043688761070370674 Loss: 0.000760576338507235 Loss: 0.0004906122339889407 Loss: 0.001050944672897458 Loss: 0.00034305930603295565 Loss: 0.000333520642016083 Loss: 0.001346747623756528 Loss: 0.000518926652148366 Loss: 0.00038751529064029455 Loss: 0.0007347655482590199 Loss: 0.0006552053382620215 Loss: 0.0004693095979746431 Loss: 0.0005048643797636032 Loss: 0.0008245222270488739 Loss: 0.0006888275966048241 Loss: 0.00044082454405725 Loss: 0.0005636227433569729 Loss: 0.0005252123810350895 Loss: 0.0007635715301148593 Loss: 0.0005486704758368433 Loss: 0.0009442295995540917 Loss: 0.0006193040171638131 Loss: 0.0013048275141045451 Loss: 0.00045766972471028566 Loss: 0.0004775456618517637 Loss: 0.0005675374995917082 Loss: 0.004480988718569279 Loss: 0.001354208099655807 Loss: 0.0023091307375580072 Loss: 0.0030480027198791504 Loss: 0.0013057851465418935 Loss: 0.0008849604055285454 Loss: 0.0009611251298338175 Loss: 0.0006008940399624407 Loss: 0.0009718112414702773 Loss: 0.0006195436581037939 Loss: 0.0004468054394237697 Loss: 0.0007318559219129384 Loss: 0.000600254803430289 Loss: 0.0007269569323398173 Loss: 0.0008202603203244507 Loss: 0.0009326831204816699 Loss: 0.0010233625071123242 Loss: 0.0016146933194249868 Loss: 0.0006531896069645882 Loss: 0.0011280549224466085 Loss: 0.0010213699424639344 Loss: 0.0008086289744824171 Loss: 0.0008692542323842645 Loss: 0.0005729414988309145 Loss: 0.0006522737094201148 Loss: 0.0012347043957561255 Loss: 0.0014926702715456486 Loss: 0.0014064067509025335 Loss: 0.0016721924766898155 Loss: 0.0010140687227249146 Loss: 0.0007766926428303123 Loss: 0.0007161159301176667 Loss: 0.0016449993709102273 Loss: 0.0016016623703762889 Loss: 0.0009757575462572277 Loss: 0.0010558852227404714 Loss: 0.002975206822156906 Loss: 0.012377133592963219 Loss: 0.0015951453242450953 Loss: 0.0025169397704303265 Loss: 0.0017778994515538216 Loss: 0.0009804833680391312 Loss: 0.0023517082445323467 Loss: 0.0028763553127646446 Loss: 0.003936590161174536 Loss: 0.001311366562731564 Loss: 0.001428470597602427 Loss: 0.0016911597922444344 Loss: 0.0011702944757416844 Loss: 0.0014147096080705523 Loss: 0.002440669108182192 Loss: 0.0013251928612589836 Loss: 0.004199718590825796 Loss: 0.0011574899544939399 Loss: 0.0014455909840762615 Loss: 0.006965372711420059 Loss: 0.0014533543726429343 Loss: 0.0020802379585802555 Loss: 0.005591132678091526 Loss: 0.0007086016703397036 Loss: 0.0014595562824979424 Loss: 0.003932422958314419 Loss: 0.001268944120965898 Loss: 0.001382922986522317 Loss: 0.003330608131363988 Loss: 0.004483602941036224 Loss: 0.003217090852558613 Loss: 0.0020847718697041273 Loss: 0.0023813711013644934 Loss: 0.0023249583318829536 Loss: 0.0025102400686591864 Loss: 0.0014276099391281605 Loss: 0.0010522021912038326 Loss: 0.0021666488610208035 Loss: 0.0010541814845055342 Loss: 0.0013616065261885524 Loss: 0.00223144399933517 Loss: 0.0017787872347980738 Loss: 0.001991346012800932 Loss: 0.0024860501289367676 Loss: 0.0017141818534582853 Loss: 0.001713483827188611 Loss: 0.002013132907450199 Loss: 0.0016368308570235968 Loss: 0.0020724921487271786 Loss: 0.001502691418863833 Loss: 0.0018795530777424574 Loss: 0.0007520514191128314 Loss: 0.0007461420027539134 Loss: 0.0022452983539551497 Loss: 0.002604313427582383 Loss: 0.0008028405718505383 Loss: 0.0025118214543908834 Loss: 0.0013362554600462317 Loss: 0.0009904057951644063 Loss: 0.0026333308778703213 Loss: 0.0013375040143728256 Loss: 0.0013061071513220668 Loss: 0.0012211732100695372 Loss: 0.0016559758223593235 Loss: 0.0007818201556801796 Loss: 0.0007729750941507518 Loss: 0.0008355433237738907 Loss: 0.0006848564371466637 Loss: 0.0009499641600996256 Loss: 0.0007108630961738527 Loss: 0.0011214565020054579 Loss: 0.00048189060180447996 Loss: 0.0007489272393286228 Loss: 0.000885029265191406 Loss: 0.0008287807577289641 Loss: 0.000563493580557406 Loss: 0.0005383663810789585 Loss: 0.0023214269895106554 Loss: 0.0020749422255903482 Loss: 0.0008023543050512671 Loss: 0.0017955926014110446 Loss: 0.000757447094656527 Loss: 0.0005793230957351625 Loss: 0.0006887734634801745 Loss: 0.0009876531548798084 Loss: 0.0005812561721540987 Loss: 0.00046745891449972987 Loss: 0.0010823803022503853 Loss: 0.0011851361487060785 Loss: 0.0006656574551016092 Loss: 0.0006380220875144005 Loss: 0.0005546664469875395 Loss: 0.0007966226548887789 Loss: 0.0006024678586982191 Loss: 0.0006565562216565013 Loss: 0.0010175753850489855 Loss: 0.001282352488487959 Loss: 0.0007072788430377841 Loss: 0.0015107081271708012 Loss: 0.0012874935055151582 Loss: 0.0008330004056915641 Loss: 0.000535622937604785 Loss: 0.0006450955988839269 Loss: 0.0006074419361539185 Loss: 0.0007074175518937409 Loss: 0.0007562537211924791 Loss: 0.0009953570552170277 Loss: 0.0014050425961613655 Loss: 0.0004630361800082028 Loss: 0.0011367471888661385 Loss: 0.0018684373935684562 Loss: 0.0012699714861810207 Loss: 0.00047859508777037263 Loss: 0.0009107966325245798 Loss: 0.0013107287231832743 Loss: 0.001922330935485661 Loss: 0.0014184056781232357 Loss: 0.0007522629457525909 Loss: 0.000422043027356267 Loss: 0.0010425536893308163 Loss: 0.0011007111752405763 Loss: 0.0011708419770002365 Loss: 0.0011846733978018165 Loss: 0.0006315461359918118 Loss: 0.0012706996640190482 Loss: 0.0014523772988468409 Loss: 0.0006138435564935207 Loss: 0.0017626138869673014 Loss: 0.0005771452561020851 Loss: 0.0010470845736563206 Loss: 0.0020099482499063015 Loss: 0.0007128794677555561 Loss: 0.0008252130355685949 Loss: 0.001020087394863367 Loss: 0.0009030108922161162 Loss: 0.0007460082415491343 Loss: 0.0006069971714168787 Loss: 0.0012493666727095842 Loss: 0.0009998545283451676 Loss: 0.0005000063101761043 Loss: 0.0013536994811147451 Loss: 0.0009585645166225731 Loss: 0.0008933030185289681 Loss: 0.0005902259144932032 Loss: 0.0023559462279081345 Loss: 0.0007550917216576636 Loss: 0.0013092129956930876 Loss: 0.0005594904650934041 Loss: 0.0008394026081077754 Loss: 0.0014222621684893966 Loss: 0.0010701077990233898 Loss: 0.0006709058070555329 Loss: 0.0014043166302144527 Loss: 0.0015168897807598114 Loss: 0.0007551686139777303 Loss: 0.0010281822178512812 Loss: 0.0007850765250623226 Loss: 0.000889840186573565 Loss: 0.0008110835333354771 Loss: 0.0009823492728173733 Loss: 0.0005367578705772758 Loss: 0.0008935442892834544 Loss: 0.0010250592604279518 Loss: 0.0007431853446178138 Loss: 0.0007468818221241236 Loss: 0.0016860202886164188 Loss: 0.0006869430071674287 Loss: 0.0006877119303680956 Loss: 0.0028378237038850784 Loss: 0.0008705396903678775 Loss: 0.0018622581847012043 Loss: 0.0016560859512537718 Loss: 0.0006787670427002013 Loss: 0.001796035561710596 Loss: 0.0012967600487172604 Loss: 0.0009896749397739768 Loss: 0.0019619313534349203 Loss: 0.0014837800990790129 Loss: 0.0011329748667776585 Loss: 0.0012308049481362104 Loss: 0.0011867510620504618 Loss: 0.0009477960411459208 Loss: 0.0011073565110564232 Loss: 0.00047491080476902425 Loss: 0.0004928068956360221 Loss: 0.0007312442758120596 Loss: 0.0009427835466340184 Loss: 0.0009006276377476752 Loss: 0.0012641862267628312 Loss: 0.0019818381406366825 Loss: 0.0012553343549370766 Loss: 0.0012674556346610188 Loss: 0.0006022984161973 Loss: 0.0021270415745675564 Loss: 0.001260775257833302 Loss: 0.0009893679525703192 Loss: 0.0016898381290957332 Loss: 0.0006888847565278411 Loss: 0.0014245009515434504 Loss: 0.0007137987995520234 Loss: 0.0006693258765153587 Loss: 0.0010441727936267853 Loss: 0.0015700546791777015 Loss: 0.0009292969480156898 Loss: 0.0007500495994463563 Loss: 0.0008475390495732427 Loss: 0.001448027789592743 Loss: 0.001073551713488996 Loss: 0.0020768537651747465 Loss: 0.001407198142260313 Loss: 0.0010403033811599016 Loss: 0.0006459858268499374 Loss: 0.0015814948128536344 Loss: 0.0010547223500907421 Loss: 0.0015277417842298746 Loss: 0.0011319030309095979 Loss: 0.0017115375958383083 Loss: 0.0014908439479768276 Loss: 0.001057351822964847 Loss: 0.0009869931964203715 Loss: 0.0018658344633877277 Loss: 0.0008516835514456034 Loss: 0.0009955026907846332 Loss: 0.0018090622033923864 Loss: 0.0011853792238980532 Loss: 0.00108595029450953 Loss: 0.0012533128028735518 Loss: 0.0013152830069884658 Loss: 0.0017013889737427235 Loss: 0.0011044272687286139 Loss: 0.0009303900296799839 Loss: 0.0014018246438354254 Loss: 0.0015869715716689825 Loss: 0.001160049345344305 Loss: 0.0012527592480182648 Loss: 0.00070428685285151 Loss: 0.0013125402620062232 Loss: 0.0021101143211126328 Loss: 0.0010340107837691903 Loss: 0.0011255182325839996 Loss: 0.0011866686400026083 Loss: 0.00203328556381166 Loss: 0.0013247504830360413 Loss: 0.001511218724772334 Loss: 0.00161873793695122 Loss: 0.0007672483334317803 Loss: 0.0011965546291321516 Loss: 0.0010649219620972872 Loss: 0.0021924981847405434 Loss: 0.0018424775917083025 Loss: 0.0015629628906026483 Loss: 0.0005752117140218616 Loss: 0.0017677939031273127 Loss: 0.0010223608696833253 Loss: 0.0015683624660596251 Loss: 0.0013369601219892502 Loss: 0.0012187148677185178 Loss: 0.0013127631973475218 Loss: 0.0021983052138239145 Loss: 0.0007943587261252105 Loss: 0.0018427509348839521 Loss: 0.0031714406795799732 Loss: 0.002107131527736783 Loss: 0.0013186594005674124 Loss: 0.001254415838047862 Loss: 0.0008977922843769193 Loss: 0.0019465356599539518 Loss: 0.0013357085408642888 Loss: 0.0013917243340983987 Loss: 0.0011719500180333853 Loss: 0.002524951007217169 Loss: 0.0005994065431877971 Loss: 0.0009107952937483788 Loss: 0.0011468140874058008 Loss: 0.0010048914700746536 Loss: 0.001208997331559658 Loss: 0.0009585996740497649 Loss: 0.0017437248025089502 Loss: 0.000889518007170409 Loss: 0.0009529950912110507 Loss: 0.0008546350873075426 Loss: 0.001270101871341467 Loss: 0.0016016976442188025 Loss: 0.0014776111347600818 Loss: 0.0013454521540552378 Loss: 0.001419242238625884 Loss: 0.000888891750946641 Loss: 0.0011973816435784101 Loss: 0.0009585467050783336 Loss: 0.001834752387367189 Loss: 0.0011517828097566962 Loss: 0.0010255178203806281 Loss: 0.0009129823301918805 Loss: 0.0009568651439622045 Loss: 0.003211995819583535 Loss: 0.0011029181769117713 Loss: 0.0017177518457174301 Loss: 0.0013423544587567449 Loss: 0.0016658417880535126 Loss: 0.001043348340317607 Loss: 0.0007647225284017622 Loss: 0.0016419119201600552 Loss: 0.001291297608986497 Loss: 0.0007930601132102311 Loss: 0.0010712259681895375 Loss: 0.0009605751256458461 Loss: 0.0010525969555601478 Loss: 0.00116306624840945 Loss: 0.002070025308057666 Loss: 0.0017251630779355764 Loss: 0.0011048103915527463 Loss: 0.0016968096606433392 Loss: 0.002608741167932749 Loss: 0.0006307986914180219 Loss: 0.0007329047657549381 Loss: 0.0010837421286851168 Loss: 0.002312293741852045 Loss: 0.0008948363829404116 Loss: 0.0005451926263049245 Loss: 0.0014810776337981224 Loss: 0.0007154321065172553 Loss: 0.0006251891609281301 Loss: 0.0015294752083718777 Loss: 0.0011080572148784995 Loss: 0.0011535331141203642 Loss: 0.000980229815468192 Loss: 0.001611695159226656 Loss: 0.0011532744392752647 Loss: 0.0026058549992740154 Loss: 0.0013095736503601074 Loss: 0.0005014284979552031 Loss: 0.00201485026627779 Loss: 0.0018339725211262703 Loss: 0.0012314997147768736 Loss: 0.0007580803940072656 Loss: 0.0015576020814478397 Loss: 0.0009253334137611091 Loss: 0.0019092496950179338 Loss: 0.0009791709017008543 Loss: 0.0011775689199566841 Loss: 0.0013838681625202298 Loss: 0.0016961873043328524 Loss: 0.0007651025662198663 Loss: 0.0016154772602021694 Loss: 0.00031895851134322584 Loss: 0.0014226112980395555 Loss: 0.003908277489244938 Loss: 0.001177951693534851 Loss: 0.0010679435217753053 Loss: 0.001458268496207893 Loss: 0.0007180237444117665 Loss: 0.001987336901947856 Loss: 0.000967807718552649 Loss: 0.0025338579434901476 Loss: 0.001234889728948474 Loss: 0.0022540781646966934 Loss: 0.0016532859299331903 Loss: 0.000884941837284714 Loss: 0.0012026582844555378 Loss: 0.002077655866742134 Loss: 0.0008688797242939472 Loss: 0.001374588580802083 Loss: 0.0009795373771339655 Loss: 0.0009474847465753555 Loss: 0.0015431537758558989 Loss: 0.0014992763753980398 Loss: 0.00224619940854609 Loss: 0.0019898926839232445 Loss: 0.0011886644642800093 Loss: 0.001195912016555667 Loss: 0.0009470513323321939 Loss: 0.0021574359852820635 Loss: 0.0017166482284665108 Loss: 0.0023729419335722923 Loss: 0.0015617401804775 Loss: 0.002103858394548297 Loss: 0.0019015774596482515 Loss: 0.0025202487595379353 Loss: 0.001155729521997273 Loss: 0.001037675654515624 Loss: 0.0013360804878175259 Loss: 0.002696119947358966 Loss: 0.0014535182854160666 Loss: 0.004336210899055004 Loss: 0.001846088794991374 Loss: 0.00272024841979146 Loss: 0.0013463783543556929 Loss: 0.0015612333081662655 Loss: 0.0009505340713076293 Loss: 0.0013244056608527899 Loss: 0.0009633367299102247 Loss: 0.0011195067781955004 Loss: 0.0011987247271463275 Loss: 0.002025905065238476 Loss: 0.0010244036093354225 Loss: 0.001045489450916648 Loss: 0.000772233703173697 Loss: 0.0015264117391780019 Loss: 0.0008213156834244728 Loss: 0.0013427824014797807 Loss: 0.0017392404843121767 Loss: 0.0005934350774623454 Loss: 0.00134235096629709 Loss: 0.0005927301826886833 Loss: 0.0013859393075108528 Loss: 0.001019574236124754 Loss: 0.0007304061437025666 Loss: 0.0006748999003320932 Loss: 0.001297136303037405 Loss: 0.0013925565872341394 Loss: 0.0006879764841869473 Loss: 0.0007899506017565727 Loss: 0.0008254270069301128 Loss: 0.0017544485162943602 Loss: 0.0013860096223652363 Loss: 0.000715556787326932 Loss: 0.0013792227255180478 Loss: 0.0010635866783559322 Loss: 0.0009579313336871564 Loss: 0.0008762564975768328 Loss: 0.0012048776261508465 Loss: 0.0010751961963251233 Loss: 0.001382380723953247 Loss: 0.003072496969252825 Loss: 0.0009861232247203588 Loss: 0.0007311536464840174 Loss: 0.0009452502126805484 Loss: 0.0010670819319784641 Loss: 0.0010411642724648118 Loss: 0.0024563304614275694 Loss: 0.0006016616243869066 Loss: 0.0011021350510418415 Loss: 0.0013145486591383815 Loss: 0.0010332380188629031 Loss: 0.0008215782581828535 Loss: 0.0012557602021843195 Loss: 0.001005598227493465 Loss: 0.001854797126725316 Loss: 0.001681183697655797 Loss: 0.0010543595999479294 Loss: 0.0017493953928351402 Loss: 0.0013288147747516632 Loss: 0.0010074828751385212 Loss: 0.0011979619739577174 Loss: 0.0007949782302603126 Loss: 0.0008158296695910394 Loss: 0.001095445710234344 Loss: 0.0015405741287395358 Loss: 0.0006520182359963655 Loss: 0.0015143337659537792 Loss: 0.0009323414415121078 Loss: 0.0014065701980143785 Loss: 0.0026730522513389587 Loss: 0.0009179461630992591 Loss: 0.001791340415365994 Loss: 0.0018697405466809869 Loss: 0.0010157522046938539 Loss: 0.0013694125227630138 Loss: 0.0016324010211974382 Loss: 0.0011204960756003857 Loss: 0.0022372305393218994 Loss: 0.0017677509458735585 Loss: 0.0015592731069773436 Loss: 0.0017771257553249598 Loss: 0.001563400262966752 Loss: 0.0006747982697561383 Loss: 0.0021602315828204155 Loss: 0.003979039844125509 Loss: 0.0011788563570007682 Loss: 0.0013144396943971515 Loss: 0.001798801589757204 Loss: 0.0020310126710683107 Loss: 0.001158342813141644 Loss: 0.0022471256088465452 Loss: 0.002206086879596114 Loss: 0.00145438383333385 Loss: 0.001231637317687273 Loss: 0.000668485532514751 Loss: 0.0017073522321879864 Loss: 0.0017827639821916819 Loss: 0.0007811461109668016 Loss: 0.0011105649173259735 Loss: 0.0018329150043427944 Loss: 0.0009835632517933846 Loss: 0.0018245907267555594 Loss: 0.0013090935535728931 Loss: 0.0019736108370125294 Loss: 0.0009373706416226923 Loss: 0.0013020422775298357 Loss: 0.002027781680226326 Loss: 0.0012231569271534681 Loss: 0.0012198350159451365 Loss: 0.001646938151679933 Loss: 0.001554321963340044 Loss: 0.0009335360373370349 Loss: 0.0011970591731369495 Loss: 0.0011865491978824139 Loss: 0.0014073492493480444 Loss: 0.0008660672465339303 Loss: 0.0011291594710201025 Loss: 0.0009769469033926725 Loss: 0.0016495260642841458 Loss: 0.0011919370153918862 Loss: 0.0014173786621540785 Loss: 0.0006484888726845384 Loss: 0.0008525124285370111 Loss: 0.0009442422888241708 Loss: 0.0006571448757313192 Loss: 0.001279762596823275 Loss: 0.0009764874121174216 Loss: 0.0007751007797196507 Loss: 0.000769660749938339 Loss: 0.0015980221796780825 Loss: 0.0012362840352579951 Loss: 0.0007637885282747447 Loss: 0.0014020904200151563 Loss: 0.0009113442501984537 Loss: 0.0008350464049726725 Loss: 0.0008936412050388753
# torch.save(cm_unet.state_dict(), "cm_unet.pt")
# cm_unet.load_state_dict(torch.load("cm_unet.pt"))
<All keys matched successfully>
Задание 5¶
Генерация с помощью обученной консистенси модели¶
Настало время погенерировать картинки с помощью нашей модели. Напомним, что мы не можем для консистенси моделей использовать DDIM и другие классические солверы для диффузии. Нам нужен специальный сэмплер для CM, который схематично изображен на картинке ниже:
Чуть более формально:
$x_{t_n} \sim {N}(0, I)$
$for\ t_i \in [t_n, ..., t_1]:$
$\epsilon \leftarrow unet(x_{t_i})$
$x_0 \leftarrow DDIM(\epsilon, x_{t_i}, t_i, 0)$
$x_{t_{i-1}} \leftarrow q(x_{t_{i-1}} | x_0)$
Classifier-free guidance (CFG)
Также вам надо реализовать поддержку CFG в CM сэмплирование. Вспомним формулу:
$\epsilon_w = {\color{blue}{\epsilon_{uncond}}} + w \cdot (\epsilon_{cond} - \epsilon_{uncond})$, где $w \geq 1$
Обратим внимание, что режим "без гайденса" соотвествует $w = 1$, что немного контринтуитивно, но в большинстве реализаций будет встречаться именно такой вид этой формулы.
@torch.no_grad()
def consistency_sampling(
pipe, prompt, num_inference_steps=4, generator=None, num_images_per_prompt=4, guidance_scale=1
):
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
device = pipe._execution_device
# Извлекаем эмбеды из текстовых промптов. Реализуйте вызов pipe.encode_prompt
do_classifier_free_guidance = guidance_scale > 1
prompt_embeds, null_prompt_embeds = pipe.encode_prompt(
prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
)
assert prompt_embeds.dtype == torch.float16
# Настраиваем параметры scheduler-a
assert pipe.scheduler.config["timestep_spacing"] == "trailing"
pipe.scheduler.set_timesteps(num_inference_steps)
# Создаем батч латентов из N(0,I)
latents = torch.randn(
(
batch_size * num_images_per_prompt,
pipe.unet.in_channels,
pipe.unet.sample_size,
pipe.unet.sample_size,
),
device=device,
generator=generator,
dtype=torch.float16,
)
for i, t in enumerate(tqdm(pipe.scheduler.timesteps)):
t = torch.tensor([t] * len(latents)).to(device)
zero_t = torch.tensor([0] * len(latents)).to(device)
cond_noise_pred = pipe.unet(
latents,
t,
encoder_hidden_states=prompt_embeds,
).sample
if do_classifier_free_guidance:
uncond_noise_pred = pipe.unet(
latents, t, encoder_hidden_states=null_prompt_embeds
).sample
noise_pred = uncond_noise_pred + guidance_scale * (cond_noise_pred - uncond_noise_pred)
else:
noise_pred = cond_noise_pred
# Получаем x_0 оценку из x_t
x_0 = ddim_solver_step(noise_pred, latents, t, zero_t, scheduler=pipe.scheduler)
if i + 1 < num_inference_steps:
# Переход на следующий шаг
s = pipe.scheduler.timesteps[i + 1]
s = torch.tensor([s] * len(latents)).to(device)
noise = torch.normal(mean=torch.zeros_like(latents), generator=generator)
latents = q_sample(x_0, s, pipe.scheduler, noise=noise)
else:
# Последний шаг
latents = x_0
latents = latents.half()
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
do_denormalize = [True] * image.shape[0]
image = pipe.image_processor.postprocess(
image, output_type="pil", do_denormalize=do_denormalize
)
return image
Попробуем сгененировать что-то нашей моделью. Можно поиграться с разными сидами и гайденс скейлами.
Референс, что примерно должно получиться на этом этапе для guidance_scale=2. Как видите, картинки стали почетче, но пока все еще так себе.

pipe.unet = cm_unet.eval().to(torch.float16)
assert cm_unet.active_adapter == "ct"
generator = torch.Generator(device="cuda").manual_seed(1)
guidance_scale = 2
# Заменяем генерацию пайплайном на наше сэмплирование.
images = consistency_sampling(
pipe=pipe,
prompt="A sad puppy with large eyes",
generator=generator,
num_images_per_prompt=4,
guidance_scale=guidance_scale,
)
visualize_images(images)
/root/miniconda3/envs/pytorch-env/lib/python3.10/site-packages/peft/tuners/lora/model.py:375: FutureWarning: Accessing config attribute `in_channels` directly via 'UNet2DConditionModel' object attribute is deprecated. Please access 'in_channels' over 'UNet2DConditionModel's config object instead, e.g. 'unet.config.in_channels'. return getattr(self.model, name)
0%| | 0/4 [00:00<?, ?it/s]
# !fuser -v /dev/nvidia* -k
Consistency Distillation¶
Задание №6¶
Теперь давайте попробуем перейти к постановке дистилляции, где шаг из $x_t$ в $x_s$ будет делаться не аналитически, а c помощью модели учителя.
$\mathbf{x}_t = q(\mathbf{x}_t | \mathbf{x}_0)$
$\mathbf{x}_s = DDIM(\epsilon_\theta(\mathbf{x}_t, t), \mathbf{x}_t, t, s)$
Замечание: В text-to-image генерации classifier-free guidance (CFG) играет очень важную роль для получения хорошего качества с помощью диффузии. CFG меняет траектории ODE и раз нам он важен, то давайте и дистиллировать траектории с CFG.
Поэтому для получения точки $\mathbf{x}_{s}$ мы будем использовать шаг учителя с CFG. Это важное отличие от CT сеттинга - там мы не можем моделировать гайденс.
unet = unet.to(torch.float32)
unet.train()
assert unet.dtype == torch.float32
# Добавляем новые LoRA адаптеры для CD модели
cm_unet.add_adapter("cd", lora_config)
cm_unet.set_adapter("cd")
# Пересоздаем оптимизатор
optimizer = torch.optim.AdamW(cm_unet.parameters(), lr=1e-4)
@torch.no_grad()
def get_xs_from_xt_with_teacher(
x_0,
x_t,
t,
s, # Not all arguments may be needed
scheduler,
prompt_embeds,
teacher_unet,
guidance_scale,
**kwargs
):
# Делаем предсказание учителем в кондишион случае: подаем эмбеды текста
cond_noise_pred = teacher_unet(
sample=x_t, timestep=t, encoder_hidden_states=prompt_embeds
).sample
# Для CFG нам нужно делать предсказания в unconditional случае.
# Для T2I моделей, мы будем это моделировать предсказаниями для пустого промпта ""
# Извлечем эмбеды из пустого промпта и размножить их до размера батча
uncond_input_ids = pipe.tokenizer(
[""], return_tensors="pt", padding="max_length", max_length=77
).input_ids.to("cuda")
uncond_prompt_embeds = pipe.text_encoder(uncond_input_ids)[0].expand(*prompt_embeds.shape)
# Затем прогоняем модель для пустых промптов
uncond_noise_pred = teacher_unet(
sample=x_t,
timestep=t,
encoder_hidden_states=uncond_prompt_embeds,
).sample
# Применяем CFG формулу и получаем итоговый предикт учителя
noise_pred = uncond_noise_pred + guidance_scale * (cond_noise_pred - uncond_noise_pred)
# noise_pred = (1 + guidance_scale) * cond_noise_pred - guidance_scale * uncond_noise_pred
# Получаем x_s из x_t
x_s = ddim_solver_step(noise_pred, x_t, t, s, scheduler)
return x_s
# Сразу зададим внутрь модель учителя и guidance_scale
get_xs_from_xt_with_teacher = functools.partial(
get_xs_from_xt_with_teacher,
teacher_unet=teacher_unet,
guidance_scale=7.5,
)
Еще, как показано в работе Improved Techniques for Training Consistency Models. L2 лосс не самый оптимальный выбор для консистенси моделей. Давайте в CD обучении также заменим MSE лосс на pseudo-huber лосс из статьи.
def pseudo_huber_loss(x: torch.Tensor, y: torch.Tensor, c=0.001):
loss = torch.mean(torch.sqrt(torch.square(x - y) + c**2) - c)
return loss
cd_loss = functools.partial(
cm_loss_template,
loss_fn=pseudo_huber_loss,
get_boundary_timesteps=get_zero_boundary_timesteps,
get_xs_from_xt=get_xs_from_xt_with_teacher,
)
assert cm_unet.active_adapter == "cd"
Теперь обучим модель в CD режиме
num_grad_accum = 2 # обновляем параметры каждые 2 шага
train_loop(cm_unet, pipe, train_dataloader, optimizer, cd_loss, num_grad_accum)
0%| | 0/625 [00:00<?, ?it/s]
Loss: 0.009496292099356651 Loss: 0.011937526986002922 Loss: 0.01138567179441452 Loss: 0.015678219497203827 Loss: 0.009431742131710052 Loss: 0.010772095993161201 Loss: 0.016768421977758408 Loss: 0.012787067331373692 Loss: 0.012087555602192879 Loss: 0.017596885561943054 Loss: 0.009169336408376694 Loss: 0.012779573909938335 Loss: 0.011597014963626862 Loss: 0.011395310051739216 Loss: 0.008672960102558136 Loss: 0.009967935271561146 Loss: 0.01345849595963955 Loss: 0.010138191282749176 Loss: 0.011719721369445324 Loss: 0.011032583191990852 Loss: 0.013462868519127369 Loss: 0.007578674703836441 Loss: 0.013694102875888348 Loss: 0.016981203109025955 Loss: 0.015872275456786156 Loss: 0.012384796515107155 Loss: 0.015420867130160332 Loss: 0.018556609749794006 Loss: 0.008531913161277771 Loss: 0.0081110168248415 Loss: 0.012532942928373814 Loss: 0.008422580547630787 Loss: 0.01047995314002037 Loss: 0.012827962636947632 Loss: 0.011391503736376762 Loss: 0.01443835161626339 Loss: 0.012678487226366997 Loss: 0.012677878141403198 Loss: 0.013434009626507759 Loss: 0.01593075878918171 Loss: 0.015876853838562965 Loss: 0.01336511317640543 Loss: 0.0109160291031003 Loss: 0.016093868762254715 Loss: 0.018056068569421768 Loss: 0.014975743368268013 Loss: 0.013509530574083328 Loss: 0.013921107165515423 Loss: 0.01753494143486023 Loss: 0.011128033511340618 Loss: 0.01605379767715931 Loss: 0.013949436135590076 Loss: 0.019396783784031868 Loss: 0.01391231082379818 Loss: 0.011289631016552448 Loss: 0.014366846531629562 Loss: 0.017452500760555267 Loss: 0.01670077256858349 Loss: 0.014370341785252094 Loss: 0.009669305756688118 Loss: 0.018039245158433914 Loss: 0.015545105561614037 Loss: 0.01798805594444275 Loss: 0.020122313871979713 Loss: 0.022047407925128937 Loss: 0.014256610535085201 Loss: 0.022882739081978798 Loss: 0.010917379520833492 Loss: 0.020264677703380585 Loss: 0.021845266222953796 Loss: 0.011254949495196342 Loss: 0.019297216087579727 Loss: 0.016736852005124092 Loss: 0.012872754596173763 Loss: 0.012026213109493256 Loss: 0.02280844748020172 Loss: 0.022811302915215492 Loss: 0.027594469487667084 Loss: 0.0183707345277071 Loss: 0.03144780918955803 Loss: 0.018408607691526413 Loss: 0.025689981877803802 Loss: 0.03236237168312073 Loss: 0.025275297462940216 Loss: 0.025577928870916367 Loss: 0.01924975961446762 Loss: 0.01932159811258316 Loss: 0.02623400092124939 Loss: 0.03247181326150894 Loss: 0.02669796720147133 Loss: 0.01757189631462097 Loss: 0.030440235510468483 Loss: 0.027177581563591957 Loss: 0.026035990566015244 Loss: 0.0243036188185215 Loss: 0.022974105551838875 Loss: 0.024366341531276703 Loss: 0.01629873737692833 Loss: 0.013320215046405792 Loss: 0.02426120638847351 Loss: 0.0341351293027401 Loss: 0.026354236528277397 Loss: 0.03175276145339012 Loss: 0.03364691883325577 Loss: 0.01664678007364273 Loss: 0.03079746663570404 Loss: 0.031287822872400284 Loss: 0.03689461573958397 Loss: 0.036125171929597855 Loss: 0.02518559619784355 Loss: 0.03783692419528961 Loss: 0.03708384931087494 Loss: 0.017160164192318916 Loss: 0.02529047429561615 Loss: 0.019349735230207443 Loss: 0.020972581580281258 Loss: 0.0210052952170372 Loss: 0.026679249480366707 Loss: 0.02700812742114067 Loss: 0.028842540457844734 Loss: 0.03562235087156296 Loss: 0.04374111443758011 Loss: 0.03097364492714405 Loss: 0.012074470520019531 Loss: 0.02567477338016033 Loss: 0.018477700650691986 Loss: 0.01651831530034542 Loss: 0.016548018902540207 Loss: 0.022041406482458115 Loss: 0.04033474624156952 Loss: 0.02967374213039875 Loss: 0.024199943989515305 Loss: 0.019739586859941483 Loss: 0.0354057252407074 Loss: 0.013216789811849594 Loss: 0.03076568804681301 Loss: 0.025727862492203712 Loss: 0.022994527593255043 Loss: 0.026900198310613632 Loss: 0.02344381809234619 Loss: 0.01690497249364853 Loss: 0.021254222840070724 Loss: 0.02345280349254608 Loss: 0.02174125239253044 Loss: 0.009578527882695198 Loss: 0.01897246576845646 Loss: 0.03501308709383011 Loss: 0.03864634037017822 Loss: 0.021854937076568604 Loss: 0.02508222497999668 Loss: 0.011814702302217484 Loss: 0.021309832111001015 Loss: 0.01157014723867178 Loss: 0.04338601976633072 Loss: 0.013866527006030083 Loss: 0.04070926457643509 Loss: 0.01488990243524313 Loss: 0.015293469652533531 Loss: 0.013237053528428078 Loss: 0.027621319517493248 Loss: 0.016671590507030487 Loss: 0.028909889981150627 Loss: 0.01657666452229023 Loss: 0.02959974855184555 Loss: 0.012124164961278439 Loss: 0.029925191774964333 Loss: 0.03146837651729584 Loss: 0.023782871663570404 Loss: 0.040163569152355194 Loss: 0.036783743649721146 Loss: 0.01889895647764206 Loss: 0.013891350477933884 Loss: 0.02183867245912552 Loss: 0.027528773993253708 Loss: 0.02591150812804699 Loss: 0.03047752007842064 Loss: 0.021378343924880028 Loss: 0.019397836178541183 Loss: 0.0239555723965168 Loss: 0.03431772440671921 Loss: 0.017908576875925064 Loss: 0.02693500742316246 Loss: 0.02115318737924099 Loss: 0.03414390608668327 Loss: 0.019907239824533463 Loss: 0.021337632089853287 Loss: 0.03761674836277962 Loss: 0.020723192021250725 Loss: 0.016744276508688927 Loss: 0.02132866531610489 Loss: 0.024194566532969475 Loss: 0.011753836646676064 Loss: 0.017650388181209564 Loss: 0.039314355701208115 Loss: 0.0260311271995306 Loss: 0.019218210130929947 Loss: 0.016347380355000496 Loss: 0.019863583147525787 Loss: 0.029556449502706528 Loss: 0.023201841861009598 Loss: 0.028976373374462128 Loss: 0.03007388487458229 Loss: 0.017817983403801918 Loss: 0.02871834486722946 Loss: 0.03152109682559967 Loss: 0.025081973522901535 Loss: 0.02826644666492939 Loss: 0.022341718897223473 Loss: 0.021324416622519493 Loss: 0.019095992669463158 Loss: 0.0325055755674839 Loss: 0.04651474207639694 Loss: 0.020459052175283432 Loss: 0.045019619166851044 Loss: 0.010087361559271812 Loss: 0.0239616297185421 Loss: 0.03117445483803749 Loss: 0.027433453127741814 Loss: 0.019093353301286697 Loss: 0.013201046735048294 Loss: 0.024706676602363586 Loss: 0.022092612460255623 Loss: 0.02983454056084156 Loss: 0.015126354061067104 Loss: 0.018014488741755486 Loss: 0.028299586847424507 Loss: 0.011203239671885967 Loss: 0.024925164878368378 Loss: 0.029286377131938934 Loss: 0.015891991555690765 Loss: 0.021685559302568436 Loss: 0.0263700932264328 Loss: 0.01141943410038948 Loss: 0.014658866450190544 Loss: 0.018147766590118408 Loss: 0.012419617734849453 Loss: 0.01503115613013506 Loss: 0.032400257885456085 Loss: 0.014440803788602352 Loss: 0.020771950483322144 Loss: 0.011454792693257332 Loss: 0.01948883943259716 Loss: 0.027998320758342743 Loss: 0.012758666649460793 Loss: 0.028439035639166832 Loss: 0.009769264608621597 Loss: 0.01960146240890026 Loss: 0.02903605066239834 Loss: 0.030036015436053276 Loss: 0.0213596411049366 Loss: 0.02497616782784462 Loss: 0.01905214600265026 Loss: 0.0327293835580349 Loss: 0.03161487728357315 Loss: 0.01919790357351303 Loss: 0.03405335545539856 Loss: 0.01812189444899559 Loss: 0.03108360804617405 Loss: 0.021709628403186798 Loss: 0.01338990405201912 Loss: 0.03325280547142029 Loss: 0.03440108522772789 Loss: 0.02291320264339447 Loss: 0.03339104354381561 Loss: 0.017286982387304306 Loss: 0.026906754821538925 Loss: 0.020866746082901955 Loss: 0.02603893354535103 Loss: 0.017456278204917908 Loss: 0.009823089465498924 Loss: 0.027211233973503113 Loss: 0.014216199517250061 Loss: 0.03865071386098862 Loss: 0.033714476972818375 Loss: 0.012833082117140293 Loss: 0.01713794656097889 Loss: 0.02267385646700859 Loss: 0.02660496160387993 Loss: 0.015578197315335274 Loss: 0.02379683591425419 Loss: 0.024973252788186073 Loss: 0.02699943073093891 Loss: 0.023325879126787186 Loss: 0.021142350509762764 Loss: 0.018284672871232033 Loss: 0.031939249485731125 Loss: 0.019638104364275932 Loss: 0.023888016119599342 Loss: 0.0175599567592144 Loss: 0.020038485527038574 Loss: 0.03570227324962616 Loss: 0.021797675639390945 Loss: 0.019452963024377823 Loss: 0.023516086861491203 Loss: 0.024924416095018387 Loss: 0.03526606783270836 Loss: 0.030382489785552025 Loss: 0.02078934758901596 Loss: 0.016493599861860275 Loss: 0.022489190101623535 Loss: 0.020600879564881325 Loss: 0.022667860612273216 Loss: 0.025904497131705284 Loss: 0.018586423248052597 Loss: 0.02097979746758938 Loss: 0.026850156486034393 Loss: 0.029706312343478203 Loss: 0.028161903843283653 Loss: 0.023522838950157166 Loss: 0.02759244665503502 Loss: 0.013026240281760693 Loss: 0.007144790608435869 Loss: 0.04768797382712364 Loss: 0.015131542459130287 Loss: 0.02647995576262474 Loss: 0.020290199667215347 Loss: 0.021466249600052834 Loss: 0.019636791199445724 Loss: 0.02598472312092781 Loss: 0.017441246658563614 Loss: 0.020738672465085983 Loss: 0.019031837582588196 Loss: 0.01593819446861744 Loss: 0.03732496127486229 Loss: 0.031962551176548004 Loss: 0.02812816947698593 Loss: 0.019592955708503723 Loss: 0.03974433243274689 Loss: 0.006029962562024593 Loss: 0.007793230935931206 Loss: 0.028224937617778778 Loss: 0.019553285092115402 Loss: 0.008423061110079288 Loss: 0.02248038537800312 Loss: 0.023505384102463722 Loss: 0.02730429358780384 Loss: 0.030865369364619255 Loss: 0.015492947772145271 Loss: 0.019171684980392456 Loss: 0.022700998932123184 Loss: 0.030046263709664345 Loss: 0.03841554373502731 Loss: 0.019631966948509216 Loss: 0.01679622009396553 Loss: 0.023311205208301544 Loss: 0.03165679797530174 Loss: 0.02817779779434204 Loss: 0.01498435065150261 Loss: 0.016916709020733833 Loss: 0.009516908787190914 Loss: 0.013914771378040314 Loss: 0.03198603168129921 Loss: 0.012353334575891495 Loss: 0.015339156612753868 Loss: 0.016120197251439095 Loss: 0.006420500576496124 Loss: 0.019626885652542114 Loss: 0.024988528341054916 Loss: 0.028647251427173615 Loss: 0.010206865146756172 Loss: 0.020918216556310654 Loss: 0.025295697152614594 Loss: 0.020878784358501434 Loss: 0.01758619397878647 Loss: 0.023583296686410904 Loss: 0.027050381526350975 Loss: 0.02011142671108246 Loss: 0.01409129612147808 Loss: 0.015736214816570282 Loss: 0.02060209959745407 Loss: 0.027128949761390686 Loss: 0.023446915671229362 Loss: 0.036001600325107574 Loss: 0.018511656671762466 Loss: 0.01920720376074314 Loss: 0.029864810407161713 Loss: 0.027200475335121155 Loss: 0.016171883791685104 Loss: 0.020199786871671677 Loss: 0.025286247953772545 Loss: 0.02033567801117897 Loss: 0.04276342689990997 Loss: 0.021857809275388718 Loss: 0.017168421298265457 Loss: 0.023361019790172577 Loss: 0.03044249303638935 Loss: 0.02784004807472229 Loss: 0.03880874812602997 Loss: 0.02639441192150116 Loss: 0.029883740469813347 Loss: 0.022406859323382378 Loss: 0.023495040833950043 Loss: 0.01571938954293728 Loss: 0.021098390221595764 Loss: 0.01676984690129757 Loss: 0.009640535339713097 Loss: 0.013287393376231194 Loss: 0.01931208185851574 Loss: 0.022366559132933617 Loss: 0.018939098343253136 Loss: 0.02624857984483242 Loss: 0.018784690648317337 Loss: 0.031175360083580017 Loss: 0.026192443445324898 Loss: 0.02186425030231476 Loss: 0.02652943879365921 Loss: 0.024367431178689003 Loss: 0.016740046441555023 Loss: 0.024467386305332184 Loss: 0.02558651939034462 Loss: 0.01736772432923317 Loss: 0.03328068554401398 Loss: 0.023520514369010925 Loss: 0.028924889862537384 Loss: 0.014891618862748146 Loss: 0.017437539994716644 Loss: 0.028767094016075134 Loss: 0.03257367014884949 Loss: 0.02516304701566696 Loss: 0.020238468423485756 Loss: 0.022964395582675934 Loss: 0.024343490600585938 Loss: 0.03130774572491646 Loss: 0.024128004908561707 Loss: 0.015969816595315933 Loss: 0.0356704480946064 Loss: 0.023618213832378387 Loss: 0.011910987086594105 Loss: 0.02276741713285446 Loss: 0.01601453870534897 Loss: 0.023953963071107864 Loss: 0.02076077088713646 Loss: 0.023621631786227226 Loss: 0.008149929344654083 Loss: 0.011193893849849701 Loss: 0.013779919594526291 Loss: 0.019075622782111168 Loss: 0.011332545429468155 Loss: 0.018374189734458923 Loss: 0.00786417443305254 Loss: 0.028014056384563446 Loss: 0.02040540799498558 Loss: 0.02935778722167015 Loss: 0.038291558623313904 Loss: 0.03702105954289436 Loss: 0.035803474485874176 Loss: 0.017483744770288467 Loss: 0.021001631394028664 Loss: 0.03384053707122803 Loss: 0.034847937524318695 Loss: 0.025064866989850998 Loss: 0.01403476856648922 Loss: 0.014985454268753529 Loss: 0.01871734857559204 Loss: 0.027287650853395462 Loss: 0.026096075773239136 Loss: 0.01895304024219513 Loss: 0.017183424904942513 Loss: 0.026206085458397865 Loss: 0.026633020490407944 Loss: 0.02216288447380066 Loss: 0.0564495213329792 Loss: 0.026784945279359818 Loss: 0.025381412357091904 Loss: 0.015770187601447105 Loss: 0.03381894528865814 Loss: 0.026263797655701637 Loss: 0.03165022283792496 Loss: 0.019144399091601372 Loss: 0.017231730744242668 Loss: 0.024026455357670784 Loss: 0.013367719948291779 Loss: 0.017525220289826393 Loss: 0.0162955354899168 Loss: 0.018693160265684128 Loss: 0.023483015596866608 Loss: 0.01597534865140915 Loss: 0.019978616386651993 Loss: 0.022129325196146965 Loss: 0.03937963768839836 Loss: 0.030721209943294525 Loss: 0.024508433416485786 Loss: 0.019966108724474907 Loss: 0.027386073023080826 Loss: 0.02077588625252247 Loss: 0.017833830788731575 Loss: 0.01819556951522827 Loss: 0.015298066660761833 Loss: 0.01772412098944187 Loss: 0.00913072470575571 Loss: 0.017517555505037308 Loss: 0.02916971780359745 Loss: 0.029484529048204422 Loss: 0.0165090374648571 Loss: 0.028805581852793694 Loss: 0.018195562064647675 Loss: 0.01519365981221199 Loss: 0.018158389255404472 Loss: 0.019854076206684113 Loss: 0.031852155923843384 Loss: 0.01860187202692032 Loss: 0.04604485630989075 Loss: 0.02576640620827675 Loss: 0.028568346053361893 Loss: 0.027869362384080887 Loss: 0.023324253037571907 Loss: 0.014252375811338425 Loss: 0.014558786526322365 Loss: 0.017063356935977936 Loss: 0.02867523767054081 Loss: 0.01717209443449974 Loss: 0.0275314599275589 Loss: 0.022404879331588745 Loss: 0.03226952999830246 Loss: 0.011252181604504585 Loss: 0.02064398303627968 Loss: 0.023048900067806244 Loss: 0.023910420015454292 Loss: 0.015921270474791527 Loss: 0.02091893181204796 Loss: 0.023140713572502136 Loss: 0.03254833072423935 Loss: 0.009130861610174179 Loss: 0.02315135858952999 Loss: 0.008089021779596806 Loss: 0.019211553037166595 Loss: 0.029322996735572815 Loss: 0.018330730497837067 Loss: 0.026580996811389923 Loss: 0.02034873142838478 Loss: 0.027433721348643303 Loss: 0.0277324877679348 Loss: 0.013611623086035252 Loss: 0.021129827946424484 Loss: 0.034579452127218246 Loss: 0.03219705820083618 Loss: 0.03291945159435272 Loss: 0.014857925474643707 Loss: 0.01701737567782402 Loss: 0.01582316681742668 Loss: 0.023910846561193466 Loss: 0.028317280113697052 Loss: 0.02134905755519867 Loss: 0.01620522327721119 Loss: 0.026204746216535568 Loss: 0.02195369079709053 Loss: 0.036061711609363556 Loss: 0.02561189793050289 Loss: 0.027346983551979065 Loss: 0.02108931541442871 Loss: 0.025453072041273117 Loss: 0.014583488926291466 Loss: 0.010639210231602192 Loss: 0.008199464529752731 Loss: 0.026678375899791718 Loss: 0.028658444061875343 Loss: 0.028008539229631424 Loss: 0.022333500906825066 Loss: 0.012294890359044075 Loss: 0.02797851897776127 Loss: 0.02465151622891426 Loss: 0.045541852712631226 Loss: 0.03206247463822365 Loss: 0.021125372499227524 Loss: 0.01975339464843273 Loss: 0.022532137110829353 Loss: 0.03348763287067413 Loss: 0.040923312306404114 Loss: 0.013663570396602154 Loss: 0.028191063553094864 Loss: 0.01757141947746277 Loss: 0.02143767476081848 Loss: 0.026715552434325218 Loss: 0.026797782629728317 Loss: 0.020081181079149246 Loss: 0.015572980977594852 Loss: 0.005896571557968855 Loss: 0.026485303416848183 Loss: 0.014912683516740799 Loss: 0.01448056660592556 Loss: 0.013832712545990944 Loss: 0.029979638755321503 Loss: 0.02159600891172886 Loss: 0.015216835774481297 Loss: 0.02701541781425476 Loss: 0.02643381804227829 Loss: 0.018457869067788124 Loss: 0.021007366478443146 Loss: 0.023309975862503052 Loss: 0.01844821311533451 Loss: 0.023704426363110542 Loss: 0.016951581463217735 Loss: 0.015772562474012375 Loss: 0.036733463406562805 Loss: 0.029061028733849525 Loss: 0.02313855290412903 Loss: 0.02438279613852501 Loss: 0.02331392839550972 Loss: 0.025476699694991112 Loss: 0.01946019008755684 Loss: 0.02864803746342659 Loss: 0.01749674789607525 Loss: 0.014430246315896511 Loss: 0.017035778611898422 Loss: 0.01911088451743126 Loss: 0.013544639572501183 Loss: 0.025872185826301575 Loss: 0.022109074518084526 Loss: 0.027618028223514557 Loss: 0.013509759679436684 Loss: 0.01674867607653141 Loss: 0.04614322632551193 Loss: 0.023746397346258163 Loss: 0.0319010466337204 Loss: 0.03761402145028114 Loss: 0.02665504440665245 Loss: 0.016361359506845474 Loss: 0.01571817882359028 Loss: 0.025462744757533073 Loss: 0.023143082857131958 Loss: 0.04236052185297012 Loss: 0.022046977654099464 Loss: 0.0151630574837327 Loss: 0.039357710629701614 Loss: 0.02524743601679802
# torch.save(cm_unet.state_dict(), "cd_unet.pt")
# cm_unet.load_state_dict(torch.load("cd_unet.pt"))
<All keys matched successfully>
Снова сэмплируем¶
Обратим внимание, что тут мы сэмпилруем без гайденса, потому что мы его уже частично прокинули в модель, когда делали шаг учителя с CFG.
Снова для референса приводим картинки на этом этапе:

Ваши картинки не обязаны совпадать: у вас могут быть немного менее/более качественные. Небольшая разница по качеству на оценку не влиет.
# Подставляем нашу новую обученную модель в пайплайн
pipe.unet = cm_unet.eval().to(torch.float16)
assert cm_unet.active_adapter == "cd"
generator = torch.Generator(device="cuda").manual_seed(0)
guidance_scale = 1
images = consistency_sampling(
pipe=pipe,
prompt="A sad puppy with large eyes",
generator=generator,
num_images_per_prompt=4,
guidance_scale=guidance_scale,
)
visualize_images(images)
/root/miniconda3/envs/pytorch-env/lib/python3.10/site-packages/peft/tuners/lora/model.py:375: FutureWarning: Accessing config attribute `in_channels` directly via 'UNet2DConditionModel' object attribute is deprecated. Please access 'in_channels' over 'UNet2DConditionModel's config object instead, e.g. 'unet.config.in_channels'. return getattr(self.model, name)
0%| | 0/4 [00:00<?, ?it/s]
Давайте посмотрим на картинки для других промптов¶
validation_prompts = [
"A sad puppy with large eyes",
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
"A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
"A girl with pale blue hair and a cami tank top",
"A lighthouse in a giant wave, origami style",
"belle epoque, christmas, red house in the forest, photo realistic, 8k",
"A small cactus with a happy face in the Sahara desert",
"Green commercial building with refrigerator and refrigeration units outside",
]
for prompt in validation_prompts:
generator = torch.Generator(device="cuda").manual_seed(0)
images = consistency_sampling(
pipe=pipe,
prompt=prompt,
generator=generator,
num_images_per_prompt=4,
guidance_scale=guidance_scale,
)
visualize_images(images)
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
Multi-boundary Сonsistency Distillation¶
В конце мы рассмотрим недавнюю модификацию CD, Multi-boundary CD, где интегрируем не всю траекторию сразу и потом сэмплируем с возвращением назад, а разбиваем траектории на $K$ отрезков и применяет CD внутри каждого отрезка независимо. Например, на картинке выше у нас два отрезка: зеленым и красным выделены две граничные точки. Для классического CD, рассмотренного ранее, у нас только одна граничная точка в $t = 0$
Обратим внимание, что сэмплирование становится детерминистичным и можно снова использовать DDIM солвер, где число шагов равно числу интервалов $K$, на которые мы разбили траектории во время обучения.
Этот метод гораздо лучше работает чем обычный CD, потому что решать задачу CD на отрезках, а не на всей траектории, гораздо проще. В текущем задании мы разобьем траекторию на $K=4$ отрезка.
Подробнее почитать можно в этой статье.
Задание №7 (0.25 балла, сдается в контесте)¶
Ниже реализуйте функцию, которая для $K=4$ отрезков будет сопоставлять таймстепам соответствующие граничные точки.
Например, для $K=2$ отрезков граничные точки будут: [0, 499]
$0 \leq t < 499$ -> граничная точка - $0$
$499 \leq t < 999$ -> граничная точка - $499$
Замечание: помним, что интервал между $t$ и $s$ - 20 шагов.
import torch
def get_multi_boundary_timesteps(
timesteps,
num_boundaries=4,
num_timesteps=1000,
):
"""
Для батча таймстепов определяем соответствующие граничные точки.
params:
timesteps: torch.Tensor(batch_size, device='cuda')
returns:
boundary_timesteps: torch.Tensor(batch_size, device='cuda')
"""
# Здесь важно повыводить timesteps и boundary_timesteps перед обучением,
# чтобы не перелетать граничные точки и при этом иногда попадать в них.
step_size = 20
boundary_points = torch.zeros_like(timesteps)
step = num_timesteps // num_boundaries
boundaries = torch.arange(0, num_timesteps - 1, step, device=timesteps.device).long()
boundaries = boundaries - (boundaries > 0).long()
for i, t in enumerate(timesteps):
if t < 0:
boundary_points[i] = 0
else:
boundary_points[i] = boundaries[boundaries <= t][-1]
return boundary_points
timesteps = torch.tensor([-1, 0, 1, 498, 499, 500, 501, 998, 999, 1000])
num_boundaries = 4 # Implied by the step size and total timesteps
num_timesteps = 1000
step_size = 20
boundary_points = get_multi_boundary_timesteps(
timesteps,
num_boundaries,
num_timesteps,
)
print(boundary_points) # Outputs the boundary points for each timestep
tensor([ 0, 0, 0, 249, 499, 499, 499, 749, 749, 749])
unet = unet.to(torch.float32)
unet.train()
assert unet.dtype == torch.float32
cm_unet.add_adapter("multi-cd", lora_config)
cm_unet.set_adapter("multi-cd")
optimizer = torch.optim.AdamW(cm_unet.parameters(), lr=1e-4)
multi_cd_loss = functools.partial(
cm_loss_template,
loss_fn=pseudo_huber_loss,
get_boundary_timesteps=get_multi_boundary_timesteps,
get_xs_from_xt=get_xs_from_xt_with_teacher,
)
assert cm_unet.active_adapter == "multi-cd"
Теперь обучим Multi-boundary CD модель
num_grad_accum = 2 # обновляем параметры каждые 2 шага
train_loop(cm_unet, pipe, train_dataloader, optimizer, multi_cd_loss, num_grad_accum)
0%| | 0/625 [00:00<?, ?it/s]
Loss: 0.0050458284094929695 Loss: 0.0049994164146482944 Loss: 0.0038845909293740988 Loss: 0.005223155952990055 Loss: 0.004456094931811094 Loss: 0.004738576244562864 Loss: 0.004106690175831318 Loss: 0.004323168657720089 Loss: 0.005026934668421745 Loss: 0.003846581093966961 Loss: 0.004206111188977957 Loss: 0.005668825935572386 Loss: 0.006089500617235899 Loss: 0.007177875842899084 Loss: 0.004582039080560207 Loss: 0.005050532519817352 Loss: 0.006586876232177019 Loss: 0.004225502721965313 Loss: 0.005653678439557552 Loss: 0.004238472785800695 Loss: 0.005283433478325605 Loss: 0.005087869241833687 Loss: 0.004655724857002497 Loss: 0.004986477084457874 Loss: 0.004622492007911205 Loss: 0.004199301823973656 Loss: 0.004161309450864792 Loss: 0.005230826325714588 Loss: 0.004746184218674898 Loss: 0.004358099773526192 Loss: 0.005007639527320862 Loss: 0.004936085548251867 Loss: 0.005806922912597656 Loss: 0.005879228003323078 Loss: 0.004877650644630194 Loss: 0.005153404548764229 Loss: 0.004234584979712963 Loss: 0.004089971072971821 Loss: 0.004984400235116482 Loss: 0.0037555380258709192 Loss: 0.004200669005513191 Loss: 0.005815621465444565 Loss: 0.0040777213871479034 Loss: 0.0038677502889186144 Loss: 0.0045357635244727135 Loss: 0.00579131068661809 Loss: 0.004972269758582115 Loss: 0.004561882931739092 Loss: 0.005287561099976301 Loss: 0.004652189090847969 Loss: 0.0053276135586202145 Loss: 0.005421602167189121 Loss: 0.004982382990419865 Loss: 0.005762035958468914 Loss: 0.003937865141779184 Loss: 0.006004189141094685 Loss: 0.004757110029459 Loss: 0.005100958980619907 Loss: 0.005118317902088165 Loss: 0.004297970328480005 Loss: 0.005094911437481642 Loss: 0.003917417023330927 Loss: 0.004822487011551857 Loss: 0.005608697421848774 Loss: 0.00601253192871809 Loss: 0.005718870088458061 Loss: 0.00427675386890769 Loss: 0.009903686121106148 Loss: 0.0044151367619633675 Loss: 0.005351589061319828 Loss: 0.003210116410627961 Loss: 0.004703446291387081 Loss: 0.004595904611051083 Loss: 0.0066633569076657295 Loss: 0.00471059326082468 Loss: 0.004878190346062183 Loss: 0.00623040646314621 Loss: 0.005952360108494759 Loss: 0.005907438695430756 Loss: 0.006119808182120323 Loss: 0.005179271101951599 Loss: 0.0046135964803397655 Loss: 0.005898504983633757 Loss: 0.004148378036916256 Loss: 0.005078914109617472 Loss: 0.005344546400010586 Loss: 0.00580773176625371 Loss: 0.00677518080919981 Loss: 0.005773784592747688 Loss: 0.005147865507751703 Loss: 0.0061773136258125305 Loss: 0.005787029396742582 Loss: 0.005595517810434103 Loss: 0.005134418606758118 Loss: 0.007175390608608723 Loss: 0.004991271533071995 Loss: 0.0049407826736569405 Loss: 0.0038854789454489946 Loss: 0.004828581586480141 Loss: 0.006333678029477596 Loss: 0.004753076937049627 Loss: 0.00465706130489707 Loss: 0.005885921884328127 Loss: 0.004284780006855726 Loss: 0.007528802379965782 Loss: 0.006056575570255518 Loss: 0.005483128130435944 Loss: 0.006047522649168968 Loss: 0.004117307718843222 Loss: 0.005675367079675198 Loss: 0.006207931321114302 Loss: 0.004521962720900774 Loss: 0.007268839981406927 Loss: 0.005056389141827822 Loss: 0.005315299145877361 Loss: 0.006997519638389349 Loss: 0.005880238488316536 Loss: 0.005715425591915846 Loss: 0.00600104033946991 Loss: 0.004890596494078636 Loss: 0.006446502171456814 Loss: 0.007522920146584511 Loss: 0.006871495395898819 Loss: 0.0044148582965135574 Loss: 0.004788204561918974 Loss: 0.007665774319320917 Loss: 0.005442911293357611 Loss: 0.005598905961960554 Loss: 0.007184822112321854 Loss: 0.005107589531689882 Loss: 0.00593933742493391 Loss: 0.005953479558229446 Loss: 0.004294685088098049 Loss: 0.006820393726229668 Loss: 0.006238329224288464 Loss: 0.004517987370491028 Loss: 0.007577195763587952 Loss: 0.004772569052875042 Loss: 0.0034602577798068523 Loss: 0.006062277127057314 Loss: 0.005392615683376789 Loss: 0.004645330831408501 Loss: 0.00564426789060235 Loss: 0.005072068423032761 Loss: 0.0051023103296756744 Loss: 0.006630321033298969 Loss: 0.004379334393888712 Loss: 0.004586417693644762 Loss: 0.00821531843394041 Loss: 0.006193439941853285 Loss: 0.004736592527478933 Loss: 0.0057826414704322815 Loss: 0.004980041179805994 Loss: 0.0058081671595573425 Loss: 0.006939011625945568 Loss: 0.006824206560850143 Loss: 0.006509170867502689 Loss: 0.005881378427147865 Loss: 0.00484424876049161 Loss: 0.005566502455621958 Loss: 0.004801922477781773 Loss: 0.005114580038934946 Loss: 0.004530574660748243 Loss: 0.0039250487461686134 Loss: 0.005654931999742985 Loss: 0.006495376117527485 Loss: 0.006822056137025356 Loss: 0.005995016545057297 Loss: 0.00372215174138546 Loss: 0.005231975112110376 Loss: 0.005164233967661858 Loss: 0.00551933329552412 Loss: 0.004443993791937828 Loss: 0.007678712718188763 Loss: 0.004685493651777506 Loss: 0.003813052549958229 Loss: 0.006269010249525309 Loss: 0.006094068754464388 Loss: 0.005769197829067707 Loss: 0.004189694300293922 Loss: 0.004578831605613232 Loss: 0.004750285763293505 Loss: 0.005363757722079754 Loss: 0.0070570590905845165 Loss: 0.0061873942613601685 Loss: 0.005730726756155491 Loss: 0.0046767136082053185 Loss: 0.005810233298689127 Loss: 0.006317509338259697 Loss: 0.006213617045432329 Loss: 0.004917256534099579 Loss: 0.005175063852220774 Loss: 0.005359809845685959 Loss: 0.007073342800140381 Loss: 0.007034657523036003 Loss: 0.005392714403569698 Loss: 0.004346674308180809 Loss: 0.005921985022723675 Loss: 0.00522988848388195 Loss: 0.0035362415947020054 Loss: 0.0043153902515769005 Loss: 0.004544549621641636 Loss: 0.005906911566853523 Loss: 0.00751439668238163 Loss: 0.006159099284559488 Loss: 0.006597897503525019 Loss: 0.0045716483145952225 Loss: 0.0038428332190960646 Loss: 0.005491739138960838 Loss: 0.005161176435649395 Loss: 0.004417429678142071 Loss: 0.005971512757241726 Loss: 0.006094412878155708 Loss: 0.005691769532859325 Loss: 0.004417594522237778 Loss: 0.004173712804913521 Loss: 0.005995198618620634 Loss: 0.005058352369815111 Loss: 0.006446205545216799 Loss: 0.004997555632144213 Loss: 0.007061402779072523 Loss: 0.004716943018138409 Loss: 0.0038628876209259033 Loss: 0.003458902705460787 Loss: 0.00551890954375267 Loss: 0.008518392220139503 Loss: 0.004729125648736954 Loss: 0.005206206813454628 Loss: 0.006306861061602831 Loss: 0.005356411449611187 Loss: 0.003877947572618723 Loss: 0.007570785935968161 Loss: 0.005646047182381153 Loss: 0.004399370402097702 Loss: 0.00495455926284194 Loss: 0.005674805957823992 Loss: 0.004989935085177422 Loss: 0.005953742191195488 Loss: 0.00609510438516736 Loss: 0.0028848438523709774 Loss: 0.003981401212513447 Loss: 0.004475269466638565 Loss: 0.005739680491387844 Loss: 0.0056493026204407215 Loss: 0.004683582112193108 Loss: 0.004250919446349144 Loss: 0.004005677066743374 Loss: 0.0069571021012961864 Loss: 0.0034569590352475643 Loss: 0.005023623816668987 Loss: 0.007230804301798344 Loss: 0.006312592886388302 Loss: 0.007624097168445587 Loss: 0.0038831306155771017 Loss: 0.007817307487130165 Loss: 0.005868788808584213 Loss: 0.004645034205168486 Loss: 0.005006207153201103 Loss: 0.0036685147788375616 Loss: 0.006274973973631859 Loss: 0.004201329778879881 Loss: 0.005159096326678991 Loss: 0.004047416616231203 Loss: 0.004243926145136356 Loss: 0.005854310002177954 Loss: 0.004543571267277002 Loss: 0.005086420103907585 Loss: 0.003759724786505103 Loss: 0.004923189990222454 Loss: 0.004369094967842102 Loss: 0.006080180872231722 Loss: 0.004846184980124235 Loss: 0.00456547224894166 Loss: 0.005371336359530687 Loss: 0.004979487042874098 Loss: 0.00532471714541316 Loss: 0.005537898279726505 Loss: 0.005277819000184536 Loss: 0.007113277912139893 Loss: 0.005858829244971275 Loss: 0.004342781379818916 Loss: 0.00618149945512414 Loss: 0.005807905923575163 Loss: 0.0044918665662407875 Loss: 0.005821374244987965 Loss: 0.005821592640131712 Loss: 0.0067418403923511505 Loss: 0.005330349318683147 Loss: 0.0048981960862874985 Loss: 0.004887910559773445 Loss: 0.00543246092274785 Loss: 0.004298563580960035 Loss: 0.005632053129374981 Loss: 0.007678564637899399 Loss: 0.0041793216951191425 Loss: 0.0038119428791105747 Loss: 0.0054582254961133 Loss: 0.007235650904476643 Loss: 0.005144327878952026 Loss: 0.005518809892237186 Loss: 0.0047637587413191795 Loss: 0.004467202350497246 Loss: 0.006964972708374262 Loss: 0.004814565647393465 Loss: 0.0072770630940794945 Loss: 0.007659660652279854 Loss: 0.006188603118062019 Loss: 0.003359612775966525 Loss: 0.0039827944710850716 Loss: 0.0041705830954015255 Loss: 0.004182526841759682 Loss: 0.004297685343772173 Loss: 0.005174009129405022 Loss: 0.00492203701287508 Loss: 0.004399633966386318 Loss: 0.004894674755632877 Loss: 0.004836228676140308 Loss: 0.005040745250880718 Loss: 0.003990960773080587 Loss: 0.004951309412717819 Loss: 0.0037009252700954676 Loss: 0.005258262623101473 Loss: 0.005119995214045048 Loss: 0.004841494373977184 Loss: 0.0043564909137785435 Loss: 0.006630360148847103 Loss: 0.0037280465476214886 Loss: 0.004347717855125666 Loss: 0.0036667990498244762 Loss: 0.005777742713689804 Loss: 0.00385462143458426 Loss: 0.005854490213096142 Loss: 0.004427099600434303 Loss: 0.005299945827573538 Loss: 0.003393057268112898 Loss: 0.0068469601683318615 Loss: 0.0037897927686572075 Loss: 0.00389404920861125 Loss: 0.0036998852156102657 Loss: 0.005442791618406773 Loss: 0.0037937595043331385 Loss: 0.004427595064043999 Loss: 0.0033183079212903976 Loss: 0.004206876270473003 Loss: 0.004943215288221836 Loss: 0.004394850227981806 Loss: 0.005368705373257399 Loss: 0.00518504623323679 Loss: 0.005397513508796692 Loss: 0.004134576302021742 Loss: 0.003661293536424637 Loss: 0.0072310250252485275 Loss: 0.00561707466840744 Loss: 0.005165843293070793 Loss: 0.006932005286216736 Loss: 0.005707199685275555 Loss: 0.0052117216400802135 Loss: 0.00782149750739336 Loss: 0.0037163624074310064 Loss: 0.006770831532776356 Loss: 0.0052643283270299435 Loss: 0.004664583597332239 Loss: 0.0058830538764595985 Loss: 0.005807559005916119 Loss: 0.004248068667948246 Loss: 0.006133052986115217 Loss: 0.004476846195757389 Loss: 0.004394937306642532 Loss: 0.004460309166461229 Loss: 0.005661469884216785 Loss: 0.0041225990280508995 Loss: 0.004351467825472355 Loss: 0.006362120620906353 Loss: 0.0043722535483539104 Loss: 0.004927045665681362 Loss: 0.005850121378898621 Loss: 0.0037984238006174564 Loss: 0.004901641979813576 Loss: 0.005034205503761768 Loss: 0.006304995156824589 Loss: 0.005167272407561541 Loss: 0.004179590381681919 Loss: 0.005129556637257338 Loss: 0.005819693207740784 Loss: 0.006344054825603962 Loss: 0.004218774847686291 Loss: 0.004722448997199535 Loss: 0.004006600938737392 Loss: 0.00467142416164279 Loss: 0.0038078853394836187 Loss: 0.0075129433535039425 Loss: 0.005242171697318554 Loss: 0.005739223212003708 Loss: 0.0035207541659474373 Loss: 0.005362201482057571 Loss: 0.0071684797294437885 Loss: 0.005062844604253769 Loss: 0.003998470492660999 Loss: 0.004947429522871971 Loss: 0.0036834331694990396 Loss: 0.008214157074689865 Loss: 0.004800674971193075 Loss: 0.006561334244906902 Loss: 0.004546507727354765 Loss: 0.004348032642155886 Loss: 0.004217318259179592 Loss: 0.00808870978653431 Loss: 0.007666631601750851 Loss: 0.0036950942594558 Loss: 0.003824038663879037 Loss: 0.006508697755634785 Loss: 0.004360540304332972 Loss: 0.0036826906725764275 Loss: 0.00482916971668601 Loss: 0.00620920117944479 Loss: 0.005616029724478722 Loss: 0.005163357127457857 Loss: 0.0037457679864019156 Loss: 0.0038069412112236023 Loss: 0.005739779211580753 Loss: 0.004011223558336496 Loss: 0.00474869180470705 Loss: 0.004466165788471699 Loss: 0.00393492728471756 Loss: 0.003772699972614646 Loss: 0.005259126424789429 Loss: 0.005252492614090443 Loss: 0.005206356290727854 Loss: 0.004400128498673439 Loss: 0.005464805290102959 Loss: 0.004545506555587053 Loss: 0.005229127127677202 Loss: 0.005244524218142033 Loss: 0.004195576533675194 Loss: 0.004898733925074339 Loss: 0.005880438722670078 Loss: 0.005460913293063641 Loss: 0.004988692235201597 Loss: 0.004901476204395294 Loss: 0.003930442500859499 Loss: 0.0038653179071843624 Loss: 0.0047339689917862415 Loss: 0.004692728631198406 Loss: 0.005765249487012625 Loss: 0.00711049372330308 Loss: 0.004393647890537977 Loss: 0.0048315743915736675 Loss: 0.006427568383514881 Loss: 0.004131690599024296 Loss: 0.0037154380697757006 Loss: 0.002854869933798909 Loss: 0.004418167285621166 Loss: 0.0051529440097510815 Loss: 0.0053727636113762856 Loss: 0.005450004246085882 Loss: 0.005508824251592159 Loss: 0.005146097391843796 Loss: 0.005144124384969473 Loss: 0.005186060443520546 Loss: 0.004354158416390419 Loss: 0.004796842113137245 Loss: 0.0049059330485761166 Loss: 0.004350817296653986 Loss: 0.0045591555535793304 Loss: 0.00465787760913372 Loss: 0.0052376920357346535 Loss: 0.005134403705596924 Loss: 0.004403100814670324 Loss: 0.004829405806958675 Loss: 0.00581878237426281 Loss: 0.0043157367035746574 Loss: 0.004899098537862301 Loss: 0.005846755113452673 Loss: 0.006113262847065926 Loss: 0.005257738288491964 Loss: 0.003660125657916069 Loss: 0.005620865151286125 Loss: 0.004661089740693569 Loss: 0.00572964595630765 Loss: 0.0060247257351875305 Loss: 0.005429534707218409 Loss: 0.004083451349288225 Loss: 0.005173895508050919 Loss: 0.006520335096865892 Loss: 0.006244057789444923 Loss: 0.005285394378006458 Loss: 0.004174549598246813 Loss: 0.004776251036673784 Loss: 0.004901340696960688 Loss: 0.003679657122120261 Loss: 0.005355440080165863 Loss: 0.004647457040846348 Loss: 0.00586352776736021 Loss: 0.005632366985082626 Loss: 0.0035310443490743637 Loss: 0.004305234644562006 Loss: 0.003553882474079728 Loss: 0.005134738050401211 Loss: 0.004639378748834133 Loss: 0.0028205523267388344 Loss: 0.006067552603781223 Loss: 0.005349949933588505 Loss: 0.004087365232408047 Loss: 0.005343243479728699 Loss: 0.004711148329079151 Loss: 0.0053049111738801 Loss: 0.005227005574852228 Loss: 0.0068631223402917385 Loss: 0.003459248458966613 Loss: 0.0049553439021110535 Loss: 0.0060326047241687775 Loss: 0.005049269646406174 Loss: 0.004182516597211361 Loss: 0.004878777079284191 Loss: 0.005989880301058292 Loss: 0.004208489786833525 Loss: 0.006403263658285141 Loss: 0.0034297527745366096 Loss: 0.005685040727257729 Loss: 0.003985801711678505 Loss: 0.004828311502933502 Loss: 0.005921314004808664 Loss: 0.005967257544398308 Loss: 0.005216366145759821 Loss: 0.006409239489585161 Loss: 0.004347200505435467 Loss: 0.005549816880375147 Loss: 0.005856629461050034 Loss: 0.006083352491259575 Loss: 0.0039038427639752626 Loss: 0.005274396855384111 Loss: 0.0046700905077159405 Loss: 0.006810040678828955 Loss: 0.004761847667396069 Loss: 0.005670033395290375 Loss: 0.004282340873032808 Loss: 0.0070357671938836575 Loss: 0.0049394043162465096 Loss: 0.004041735082864761 Loss: 0.005748603492975235 Loss: 0.0051040975376963615 Loss: 0.005032747518271208 Loss: 0.005162765737622976 Loss: 0.006514610256999731 Loss: 0.004632913041859865 Loss: 0.004685945808887482 Loss: 0.005423092283308506 Loss: 0.005085199140012264 Loss: 0.006271375808864832 Loss: 0.005126899108290672 Loss: 0.005000355653464794 Loss: 0.004179814830422401 Loss: 0.003980487119406462 Loss: 0.0037012884858995676 Loss: 0.005241088569164276 Loss: 0.00599312037229538 Loss: 0.004957483150064945 Loss: 0.004777251742780209 Loss: 0.006968651432543993 Loss: 0.004508870653808117 Loss: 0.00495144072920084 Loss: 0.006375170312821865 Loss: 0.005465688183903694 Loss: 0.005821937695145607 Loss: 0.004899467807263136 Loss: 0.005708535201847553 Loss: 0.00423073535785079 Loss: 0.005874833557754755 Loss: 0.004133216105401516 Loss: 0.004646312445402145 Loss: 0.005897577852010727 Loss: 0.006291603669524193 Loss: 0.007561618462204933 Loss: 0.006860073190182447 Loss: 0.003710885066539049 Loss: 0.0052077993750572205 Loss: 0.006266698706895113 Loss: 0.005054624751210213 Loss: 0.0038261539302766323 Loss: 0.003448259085416794 Loss: 0.00894455797970295 Loss: 0.005979649722576141 Loss: 0.00582055002450943 Loss: 0.005652684718370438 Loss: 0.0067834677174687386 Loss: 0.0031695635989308357 Loss: 0.008927084505558014 Loss: 0.003964771516621113 Loss: 0.005587117746472359 Loss: 0.007189561612904072 Loss: 0.006516185589134693 Loss: 0.007592527661472559 Loss: 0.004328534007072449 Loss: 0.005083056632429361 Loss: 0.004025323782116175 Loss: 0.005534962750971317 Loss: 0.0044294619001448154 Loss: 0.007328921463340521 Loss: 0.006104005500674248 Loss: 0.00541450222954154 Loss: 0.0064305211417376995 Loss: 0.004474438726902008 Loss: 0.007335943169891834 Loss: 0.005208642687648535 Loss: 0.00733292056247592 Loss: 0.006287113297730684 Loss: 0.0038627712056040764 Loss: 0.006797074340283871 Loss: 0.004658001475036144 Loss: 0.006570708472281694 Loss: 0.0049662841483950615 Loss: 0.006000400520861149 Loss: 0.006568297743797302 Loss: 0.006987361237406731 Loss: 0.004486429039388895 Loss: 0.0037397085689008236 Loss: 0.006019908003509045 Loss: 0.004593628458678722 Loss: 0.0035399883054196835 Loss: 0.004847708158195019 Loss: 0.0043555619195103645 Loss: 0.005846542306244373 Loss: 0.006565350107848644 Loss: 0.006879416760057211 Loss: 0.0040451898239552975
# torch.save(cm_unet.state_dict(), "mb_unet.pt")
# cm_unet.load_state_dict(torch.load("mb_unet.pt"))
И в последний раз сэмплируем¶
Важно: теперь у нас появляется возможно сэмплировать детерминистично с помощью оригинального солвера DDIM за 4 шага. Так что возвращаем сэмплирование исходным pipe-ом.
Ниже прикрепляем референс и напомним, что у вас картинки могут отличаться и быть чуть хуже/лучше.

pipe.unet = cm_unet.eval().to(torch.float16)
assert cm_unet.active_adapter == "multi-cd"
guidance_scale = 1
for prompt in validation_prompts:
generator = torch.Generator(device="cuda").manual_seed(1)
images = pipe(
prompt,
generator=generator,
num_inference_steps=4,
guidance_scale=guidance_scale,
num_images_per_prompt=4,
).images # type: ignore
visualize_images(images)
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
Задание №8¶
Все, что осталось сделать - это загрузить ваши обученные модельки на huggingface_hub. Это очень популярный и удобный способ для хранения моделей, которые легко можно загружать и подставлять в модель. Другими словами GitHub для моделей и датасетов.
Создайте аккаунт на huggingface.co
Получите свой HF токен, который можно получить здесь: https://huggingface.co/settings/tokens
Создайте репозиторий для ваших моделями https://huggingface.co/new
Важно: перед отправкой нотбука на проверку, не забудьте удалить свой HF токен!
cm_unet.push_to_hub(
"jd-salinger/cv-week-final-task", # "<username>/<repo-name>"
token="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",
)
README.md: 0%| | 0.00/31.0 [00:00<?, ?B/s]
Upload 3 LFS files: 0%| | 0/3 [00:00<?, ?it/s]
adapter_model.safetensors: 0%| | 0.00/135M [00:00<?, ?B/s]
adapter_model.safetensors: 0%| | 0.00/135M [00:00<?, ?B/s]
adapter_model.safetensors: 0%| | 0.00/269M [00:00<?, ?B/s]
CommitInfo(commit_url='https://huggingface.co/jd-salinger/cv-week-final-task/commit/be53414d0dd081ab5d1892f3c9c46a4ecb715f6c', commit_message='Upload model', commit_description='', oid='be53414d0dd081ab5d1892f3c9c46a4ecb715f6c', pr_url=None, repo_url=RepoUrl('https://huggingface.co/jd-salinger/cv-week-final-task', endpoint='https://huggingface.co', repo_type='model', repo_id='jd-salinger/cv-week-final-task'), pr_revision=None, pr_num=None)
Пример, как должен выглядеть результат выполнения команды: https://huggingface.co/dbaranchuk/cv-week-final-task-example
Давайте проверим, что загрузка модели корректно работает.
from peft import PeftModel
loaded_cm_unet = PeftModel.from_pretrained(
unet,
"jd-salinger/cv-week-final-task",
token="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",
subfolder="multi-cd",
adapter_name="multi-cd",
)
multi-cd/adapter_config.json: 0%| | 0.00/1.02k [00:00<?, ?B/s]
adapter_model.safetensors: 0%| | 0.00/135M [00:00<?, ?B/s]
pipe.unet = loaded_cm_unet.eval().to(torch.float16)
assert loaded_cm_unet.active_adapter == "multi-cd"
guidance_scale = 1
for prompt in validation_prompts:
generator = torch.Generator(device="cuda").manual_seed(1)
images = pipe(
prompt,
generator=generator,
num_inference_steps=4,
guidance_scale=guidance_scale,
num_images_per_prompt=4,
).images # type: ignore
visualize_images(images)
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]
0%| | 0/4 [00:00<?, ?it/s]



